mirror of
https://github.com/intel/llvm.git
synced 2026-01-25 01:07:04 +08:00
[mlir][EDSC] Retire ValueHandle
The EDSC discussion [thread](https://llvm.discourse.group/t/evolving-builder-apis-based-on-lessons-learned-from-edsc/879) points out that ValueHandle has become an unnecessary level of abstraction since MLIR switched from `Value *` to `Value` everywhere. This revision removes this level of indirection.
This commit is contained in:
@@ -531,9 +531,9 @@ llvm::Optional<SmallVector<AffineMap, 8>> batchmatmul::referenceIndexingMaps() {
|
||||
void batchmatmul::regionBuilder(ArrayRef<BlockArgument> args) {
|
||||
using namespace edsc;
|
||||
using namespace intrinsics;
|
||||
ValueHandle _0(args[0]), _1(args[1]), _2(args[2]);
|
||||
ValueHandle _4 = std_mulf(_0, _1);
|
||||
ValueHandle _5 = std_addf(_2, _4);
|
||||
Value _0(args[0]), _1(args[1]), _2(args[2]);
|
||||
Value _4 = std_mulf(_0, _1);
|
||||
Value _5 = std_addf(_2, _4);
|
||||
(linalg_yield(ValueRange{ _5 }));
|
||||
}
|
||||
```
|
||||
|
||||
@@ -13,30 +13,17 @@ case, in C++.
|
||||
supporting a simple declarative API with globally accessible builders. These
|
||||
declarative builders are available within the lifetime of a `ScopedContext`.
|
||||
|
||||
## ValueHandle and IndexHandle
|
||||
|
||||
`mlir::edsc::ValueHandle` and `mlir::edsc::IndexHandle` provide typed
|
||||
abstractions around an `mlir::Value`. These abstractions are "delayed", in the
|
||||
sense that they allow separating declaration from definition. They may capture
|
||||
IR snippets, as they are built, for programmatic manipulation. Intuitive
|
||||
operators are provided to allow concise and idiomatic expressions.
|
||||
|
||||
```c++
|
||||
ValueHandle zero = std_constant_index(0);
|
||||
IndexHandle i, j, k;
|
||||
```
|
||||
|
||||
## Intrinsics
|
||||
|
||||
`mlir::edsc::ValueBuilder` is a generic wrapper for the `mlir::Builder::create`
|
||||
method that operates on `ValueHandle` objects and return a single ValueHandle.
|
||||
For instructions that return no values or that return multiple values, the
|
||||
`mlir::edsc::InstructionBuilder` can be used. Named intrinsics are provided as
|
||||
`mlir::edsc::ValueBuilder` is a generic wrapper for the `mlir::OpBuilder::create`
|
||||
method that operates on `Value` objects and returns a single `Value`. For
|
||||
instructions that return no values or that return multiple values, the
|
||||
`mlir::edsc::OperationBuilder` can be used. Named intrinsics are provided as
|
||||
syntactic sugar to further reduce boilerplate.
|
||||
|
||||
```c++
|
||||
using load = ValueBuilder<LoadOp>;
|
||||
using store = InstructionBuilder<StoreOp>;
|
||||
using store = OperationBuilder<StoreOp>;
|
||||
```
|
||||
|
||||
## LoopBuilder and AffineLoopNestBuilder
|
||||
@@ -46,14 +33,11 @@ concise and structured loop nests.
|
||||
|
||||
```c++
|
||||
ScopedContext scope(f.get());
|
||||
ValueHandle i(indexType),
|
||||
j(indexType),
|
||||
lb(f->getArgument(0)),
|
||||
ub(f->getArgument(1));
|
||||
ValueHandle f7(std_constant_float(llvm::APFloat(7.0f), f32Type)),
|
||||
f13(std_constant_float(llvm::APFloat(13.0f), f32Type)),
|
||||
i7(constant_int(7, 32)),
|
||||
i13(constant_int(13, 32));
|
||||
Value i, j, lb(f->getArgument(0)), ub(f->getArgument(1));
|
||||
Value f7(std_constant_float(llvm::APFloat(7.0f), f32Type)),
|
||||
f13(std_constant_float(llvm::APFloat(13.0f), f32Type)),
|
||||
i7(constant_int(7, 32)),
|
||||
i13(constant_int(13, 32));
|
||||
AffineLoopNestBuilder(&i, lb, ub, 3)([&]{
|
||||
lb * index_type(3) + ub;
|
||||
lb + index_type(3);
|
||||
@@ -84,11 +68,10 @@ def AddOp : Op<"x.add">,
|
||||
Arguments<(ins Tensor:$A, Tensor:$B)>,
|
||||
Results<(outs Tensor: $C)> {
|
||||
code referenceImplementation = [{
|
||||
auto ivs = makeIndexHandles(view_A.rank());
|
||||
auto pivs = makePIndexHandles(ivs);
|
||||
SmallVector<Value, 4> ivs(view_A.rank());
|
||||
IndexedValue A(arg_A), B(arg_B), C(arg_C);
|
||||
AffineLoopNestBuilder(pivs, view_A.getLbs(), view_A.getUbs(), view_A.getSteps())(
|
||||
[&]{
|
||||
AffineLoopNestBuilder(
|
||||
ivs, view_A.getLbs(), view_A.getUbs(), view_A.getSteps())([&]{
|
||||
C(ivs) = A(ivs) + B(ivs)
|
||||
});
|
||||
}];
|
||||
@@ -124,10 +107,4 @@ Similar APIs are provided to emit the lower-level `loop.for` op with
|
||||
`LoopNestBuilder`. See the `builder-api-test.cpp` test for more usage examples.
|
||||
|
||||
Since the implementation of declarative builders is in C++, it is also available
|
||||
to program the IR with an embedded-DSL flavor directly integrated in MLIR. We
|
||||
make use of these properties in the tutorial.
|
||||
|
||||
Spoiler: MLIR also provides Python bindings for these builders, and a
|
||||
full-fledged Python machine learning DSL with automatic differentiation
|
||||
targeting MLIR was built as an early research collaboration.
|
||||
|
||||
to program the IR with an embedded-DSL flavor directly integrated in MLIR.
|
||||
|
||||
@@ -23,12 +23,10 @@ namespace mlir {
|
||||
namespace edsc {
|
||||
|
||||
/// Constructs a new AffineForOp and captures the associated induction
|
||||
/// variable. A ValueHandle pointer is passed as the first argument and is the
|
||||
/// variable. A Value pointer is passed as the first argument and is the
|
||||
/// *only* way to capture the loop induction variable.
|
||||
LoopBuilder makeAffineLoopBuilder(ValueHandle *iv,
|
||||
ArrayRef<ValueHandle> lbHandles,
|
||||
ArrayRef<ValueHandle> ubHandles,
|
||||
int64_t step);
|
||||
LoopBuilder makeAffineLoopBuilder(Value *iv, ArrayRef<Value> lbs,
|
||||
ArrayRef<Value> ubs, int64_t step);
|
||||
|
||||
/// Explicit nested LoopBuilder. Offers a compressed multi-loop builder to avoid
|
||||
/// explicitly writing all the loops in a nest. This simple functionality is
|
||||
@@ -58,10 +56,10 @@ public:
|
||||
/// This entry point accommodates the fact that AffineForOp implicitly uses
|
||||
/// multiple `lbs` and `ubs` with one single `iv` and `step` to encode `max`
|
||||
/// and `min` constraints respectively.
|
||||
AffineLoopNestBuilder(ValueHandle *iv, ArrayRef<ValueHandle> lbs,
|
||||
ArrayRef<ValueHandle> ubs, int64_t step);
|
||||
AffineLoopNestBuilder(ArrayRef<ValueHandle *> ivs, ArrayRef<ValueHandle> lbs,
|
||||
ArrayRef<ValueHandle> ubs, ArrayRef<int64_t> steps);
|
||||
AffineLoopNestBuilder(Value *iv, ArrayRef<Value> lbs, ArrayRef<Value> ubs,
|
||||
int64_t step);
|
||||
AffineLoopNestBuilder(MutableArrayRef<Value> ivs, ArrayRef<Value> lbs,
|
||||
ArrayRef<Value> ubs, ArrayRef<int64_t> steps);
|
||||
|
||||
void operator()(function_ref<void(void)> fun = nullptr);
|
||||
|
||||
@@ -71,133 +69,134 @@ private:
|
||||
|
||||
namespace op {
|
||||
|
||||
ValueHandle operator+(ValueHandle lhs, ValueHandle rhs);
|
||||
ValueHandle operator-(ValueHandle lhs, ValueHandle rhs);
|
||||
ValueHandle operator*(ValueHandle lhs, ValueHandle rhs);
|
||||
ValueHandle operator/(ValueHandle lhs, ValueHandle rhs);
|
||||
ValueHandle operator%(ValueHandle lhs, ValueHandle rhs);
|
||||
ValueHandle floorDiv(ValueHandle lhs, ValueHandle rhs);
|
||||
ValueHandle ceilDiv(ValueHandle lhs, ValueHandle rhs);
|
||||
Value operator+(Value lhs, Value rhs);
|
||||
Value operator-(Value lhs, Value rhs);
|
||||
Value operator*(Value lhs, Value rhs);
|
||||
Value operator/(Value lhs, Value rhs);
|
||||
Value operator%(Value lhs, Value rhs);
|
||||
Value floorDiv(Value lhs, Value rhs);
|
||||
Value ceilDiv(Value lhs, Value rhs);
|
||||
|
||||
ValueHandle operator!(ValueHandle value);
|
||||
ValueHandle operator&&(ValueHandle lhs, ValueHandle rhs);
|
||||
ValueHandle operator||(ValueHandle lhs, ValueHandle rhs);
|
||||
ValueHandle operator^(ValueHandle lhs, ValueHandle rhs);
|
||||
ValueHandle operator==(ValueHandle lhs, ValueHandle rhs);
|
||||
ValueHandle operator!=(ValueHandle lhs, ValueHandle rhs);
|
||||
ValueHandle operator<(ValueHandle lhs, ValueHandle rhs);
|
||||
ValueHandle operator<=(ValueHandle lhs, ValueHandle rhs);
|
||||
ValueHandle operator>(ValueHandle lhs, ValueHandle rhs);
|
||||
ValueHandle operator>=(ValueHandle lhs, ValueHandle rhs);
|
||||
/// Logical operator overloadings.
|
||||
Value negate(Value value);
|
||||
Value operator&&(Value lhs, Value rhs);
|
||||
Value operator||(Value lhs, Value rhs);
|
||||
Value operator^(Value lhs, Value rhs);
|
||||
|
||||
/// Comparison operator overloadings.
|
||||
Value eq(Value lhs, Value rhs);
|
||||
Value ne(Value lhs, Value rhs);
|
||||
Value operator<(Value lhs, Value rhs);
|
||||
Value operator<=(Value lhs, Value rhs);
|
||||
Value operator>(Value lhs, Value rhs);
|
||||
Value operator>=(Value lhs, Value rhs);
|
||||
|
||||
} // namespace op
|
||||
|
||||
/// Arithmetic operator overloadings.
|
||||
template <typename Load, typename Store>
|
||||
ValueHandle TemplatedIndexedValue<Load, Store>::operator+(ValueHandle e) {
|
||||
Value TemplatedIndexedValue<Load, Store>::operator+(Value e) {
|
||||
using op::operator+;
|
||||
return static_cast<ValueHandle>(*this) + e;
|
||||
return static_cast<Value>(*this) + e;
|
||||
}
|
||||
template <typename Load, typename Store>
|
||||
ValueHandle TemplatedIndexedValue<Load, Store>::operator-(ValueHandle e) {
|
||||
Value TemplatedIndexedValue<Load, Store>::operator-(Value e) {
|
||||
using op::operator-;
|
||||
return static_cast<ValueHandle>(*this) - e;
|
||||
return static_cast<Value>(*this) - e;
|
||||
}
|
||||
template <typename Load, typename Store>
|
||||
ValueHandle TemplatedIndexedValue<Load, Store>::operator*(ValueHandle e) {
|
||||
Value TemplatedIndexedValue<Load, Store>::operator*(Value e) {
|
||||
using op::operator*;
|
||||
return static_cast<ValueHandle>(*this) * e;
|
||||
return static_cast<Value>(*this) * e;
|
||||
}
|
||||
template <typename Load, typename Store>
|
||||
ValueHandle TemplatedIndexedValue<Load, Store>::operator/(ValueHandle e) {
|
||||
Value TemplatedIndexedValue<Load, Store>::operator/(Value e) {
|
||||
using op::operator/;
|
||||
return static_cast<ValueHandle>(*this) / e;
|
||||
return static_cast<Value>(*this) / e;
|
||||
}
|
||||
template <typename Load, typename Store>
|
||||
ValueHandle TemplatedIndexedValue<Load, Store>::operator%(ValueHandle e) {
|
||||
Value TemplatedIndexedValue<Load, Store>::operator%(Value e) {
|
||||
using op::operator%;
|
||||
return static_cast<ValueHandle>(*this) % e;
|
||||
return static_cast<Value>(*this) % e;
|
||||
}
|
||||
template <typename Load, typename Store>
|
||||
ValueHandle TemplatedIndexedValue<Load, Store>::operator^(ValueHandle e) {
|
||||
Value TemplatedIndexedValue<Load, Store>::operator^(Value e) {
|
||||
using op::operator^;
|
||||
return static_cast<ValueHandle>(*this) ^ e;
|
||||
return static_cast<Value>(*this) ^ e;
|
||||
}
|
||||
|
||||
/// Assignment-arithmetic operator overloadings.
|
||||
template <typename Load, typename Store>
|
||||
OperationHandle TemplatedIndexedValue<Load, Store>::operator+=(ValueHandle e) {
|
||||
OperationHandle TemplatedIndexedValue<Load, Store>::operator+=(Value e) {
|
||||
using op::operator+;
|
||||
return Store(*this + e, getBase(), {indices.begin(), indices.end()});
|
||||
}
|
||||
template <typename Load, typename Store>
|
||||
OperationHandle TemplatedIndexedValue<Load, Store>::operator-=(ValueHandle e) {
|
||||
OperationHandle TemplatedIndexedValue<Load, Store>::operator-=(Value e) {
|
||||
using op::operator-;
|
||||
return Store(*this - e, getBase(), {indices.begin(), indices.end()});
|
||||
}
|
||||
template <typename Load, typename Store>
|
||||
OperationHandle TemplatedIndexedValue<Load, Store>::operator*=(ValueHandle e) {
|
||||
OperationHandle TemplatedIndexedValue<Load, Store>::operator*=(Value e) {
|
||||
using op::operator*;
|
||||
return Store(*this * e, getBase(), {indices.begin(), indices.end()});
|
||||
}
|
||||
template <typename Load, typename Store>
|
||||
OperationHandle TemplatedIndexedValue<Load, Store>::operator/=(ValueHandle e) {
|
||||
OperationHandle TemplatedIndexedValue<Load, Store>::operator/=(Value e) {
|
||||
using op::operator/;
|
||||
return Store(*this / e, getBase(), {indices.begin(), indices.end()});
|
||||
}
|
||||
template <typename Load, typename Store>
|
||||
OperationHandle TemplatedIndexedValue<Load, Store>::operator%=(ValueHandle e) {
|
||||
OperationHandle TemplatedIndexedValue<Load, Store>::operator%=(Value e) {
|
||||
using op::operator%;
|
||||
return Store(*this % e, getBase(), {indices.begin(), indices.end()});
|
||||
}
|
||||
template <typename Load, typename Store>
|
||||
OperationHandle TemplatedIndexedValue<Load, Store>::operator^=(ValueHandle e) {
|
||||
OperationHandle TemplatedIndexedValue<Load, Store>::operator^=(Value e) {
|
||||
using op::operator^;
|
||||
return Store(*this ^ e, getBase(), {indices.begin(), indices.end()});
|
||||
}
|
||||
|
||||
/// Logical operator overloadings.
|
||||
template <typename Load, typename Store>
|
||||
ValueHandle TemplatedIndexedValue<Load, Store>::operator&&(ValueHandle e) {
|
||||
Value TemplatedIndexedValue<Load, Store>::operator&&(Value e) {
|
||||
using op::operator&&;
|
||||
return static_cast<ValueHandle>(*this) && e;
|
||||
return static_cast<Value>(*this) && e;
|
||||
}
|
||||
template <typename Load, typename Store>
|
||||
ValueHandle TemplatedIndexedValue<Load, Store>::operator||(ValueHandle e) {
|
||||
Value TemplatedIndexedValue<Load, Store>::operator||(Value e) {
|
||||
using op::operator||;
|
||||
return static_cast<ValueHandle>(*this) || e;
|
||||
return static_cast<Value>(*this) || e;
|
||||
}
|
||||
|
||||
/// Comparison operator overloadings.
|
||||
template <typename Load, typename Store>
|
||||
ValueHandle TemplatedIndexedValue<Load, Store>::operator==(ValueHandle e) {
|
||||
using op::operator==;
|
||||
return static_cast<ValueHandle>(*this) == e;
|
||||
Value TemplatedIndexedValue<Load, Store>::eq(Value e) {
|
||||
return eq(value, e);
|
||||
}
|
||||
template <typename Load, typename Store>
|
||||
ValueHandle TemplatedIndexedValue<Load, Store>::operator!=(ValueHandle e) {
|
||||
using op::operator!=;
|
||||
return static_cast<ValueHandle>(*this) != e;
|
||||
Value TemplatedIndexedValue<Load, Store>::ne(Value e) {
|
||||
return ne(value, e);
|
||||
}
|
||||
template <typename Load, typename Store>
|
||||
ValueHandle TemplatedIndexedValue<Load, Store>::operator<(ValueHandle e) {
|
||||
Value TemplatedIndexedValue<Load, Store>::operator<(Value e) {
|
||||
using op::operator<;
|
||||
return static_cast<ValueHandle>(*this) < e;
|
||||
return static_cast<Value>(*this) < e;
|
||||
}
|
||||
template <typename Load, typename Store>
|
||||
ValueHandle TemplatedIndexedValue<Load, Store>::operator<=(ValueHandle e) {
|
||||
Value TemplatedIndexedValue<Load, Store>::operator<=(Value e) {
|
||||
using op::operator<=;
|
||||
return static_cast<ValueHandle>(*this) <= e;
|
||||
return static_cast<Value>(*this) <= e;
|
||||
}
|
||||
template <typename Load, typename Store>
|
||||
ValueHandle TemplatedIndexedValue<Load, Store>::operator>(ValueHandle e) {
|
||||
Value TemplatedIndexedValue<Load, Store>::operator>(Value e) {
|
||||
using op::operator>;
|
||||
return static_cast<ValueHandle>(*this) > e;
|
||||
return static_cast<Value>(*this) > e;
|
||||
}
|
||||
template <typename Load, typename Store>
|
||||
ValueHandle TemplatedIndexedValue<Load, Store>::operator>=(ValueHandle e) {
|
||||
Value TemplatedIndexedValue<Load, Store>::operator>=(Value e) {
|
||||
using op::operator>=;
|
||||
return static_cast<ValueHandle>(*this) >= e;
|
||||
return static_cast<Value>(*this) >= e;
|
||||
}
|
||||
|
||||
} // namespace edsc
|
||||
|
||||
@@ -42,11 +42,10 @@ class ParallelLoopNestBuilder;
|
||||
class LoopRangeBuilder : public NestedBuilder {
|
||||
public:
|
||||
/// Constructs a new loop.for and captures the associated induction
|
||||
/// variable. A ValueHandle pointer is passed as the first argument and is the
|
||||
/// variable. A Value pointer is passed as the first argument and is the
|
||||
/// *only* way to capture the loop induction variable.
|
||||
LoopRangeBuilder(ValueHandle *iv, ValueHandle range);
|
||||
LoopRangeBuilder(ValueHandle *iv, Value range);
|
||||
LoopRangeBuilder(ValueHandle *iv, SubViewOp::Range range);
|
||||
LoopRangeBuilder(Value *iv, Value range);
|
||||
LoopRangeBuilder(Value *iv, SubViewOp::Range range);
|
||||
|
||||
LoopRangeBuilder(const LoopRangeBuilder &) = delete;
|
||||
LoopRangeBuilder(LoopRangeBuilder &&) = default;
|
||||
@@ -57,7 +56,7 @@ public:
|
||||
/// The only purpose of this operator is to serve as a sequence point so that
|
||||
/// the evaluation of `fun` (which builds IR snippets in a scoped fashion) is
|
||||
/// scoped within a LoopRangeBuilder.
|
||||
ValueHandle operator()(std::function<void(void)> fun = nullptr);
|
||||
Value operator()(std::function<void(void)> fun = nullptr);
|
||||
};
|
||||
|
||||
/// Helper class to sugar building loop.for loop nests from ranges.
|
||||
@@ -65,13 +64,10 @@ public:
|
||||
/// directly. In the current implementation it produces loop.for operations.
|
||||
class LoopNestRangeBuilder {
|
||||
public:
|
||||
LoopNestRangeBuilder(ArrayRef<edsc::ValueHandle *> ivs,
|
||||
ArrayRef<edsc::ValueHandle> ranges);
|
||||
LoopNestRangeBuilder(ArrayRef<edsc::ValueHandle *> ivs,
|
||||
ArrayRef<Value> ranges);
|
||||
LoopNestRangeBuilder(ArrayRef<edsc::ValueHandle *> ivs,
|
||||
LoopNestRangeBuilder(MutableArrayRef<Value> ivs, ArrayRef<Value> ranges);
|
||||
LoopNestRangeBuilder(MutableArrayRef<Value> ivs,
|
||||
ArrayRef<SubViewOp::Range> ranges);
|
||||
edsc::ValueHandle operator()(std::function<void(void)> fun = nullptr);
|
||||
Value operator()(std::function<void(void)> fun = nullptr);
|
||||
|
||||
private:
|
||||
SmallVector<LoopRangeBuilder, 4> loops;
|
||||
@@ -81,7 +77,7 @@ private:
|
||||
/// ranges.
|
||||
template <typename LoopTy> class GenericLoopNestRangeBuilder {
|
||||
public:
|
||||
GenericLoopNestRangeBuilder(ArrayRef<edsc::ValueHandle *> ivs,
|
||||
GenericLoopNestRangeBuilder(MutableArrayRef<Value> ivs,
|
||||
ArrayRef<Value> ranges);
|
||||
void operator()(std::function<void(void)> fun = nullptr) { (*builder)(fun); }
|
||||
|
||||
@@ -124,7 +120,6 @@ Operation *makeGenericLinalgOp(
|
||||
|
||||
namespace ops {
|
||||
using edsc::StructuredIndexed;
|
||||
using edsc::ValueHandle;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// EDSC builders for linalg generic operations.
|
||||
@@ -160,7 +155,7 @@ void macRegionBuilder(ArrayRef<BlockArgument> args);
|
||||
/// with in-place semantics and parallelism.
|
||||
|
||||
/// Unary pointwise operation (with broadcast) entry point.
|
||||
using UnaryPointwiseOpBuilder = function_ref<Value(ValueHandle)>;
|
||||
using UnaryPointwiseOpBuilder = function_ref<Value(Value)>;
|
||||
Operation *linalg_generic_pointwise(UnaryPointwiseOpBuilder unaryOp,
|
||||
StructuredIndexed I, StructuredIndexed O);
|
||||
|
||||
@@ -171,7 +166,7 @@ Operation *linalg_generic_pointwise_tanh(StructuredIndexed I,
|
||||
StructuredIndexed O);
|
||||
|
||||
/// Binary pointwise operation (with broadcast) entry point.
|
||||
using BinaryPointwiseOpBuilder = function_ref<Value(ValueHandle, ValueHandle)>;
|
||||
using BinaryPointwiseOpBuilder = function_ref<Value(Value, Value)>;
|
||||
Operation *linalg_generic_pointwise(BinaryPointwiseOpBuilder binaryOp,
|
||||
StructuredIndexed I1, StructuredIndexed I2,
|
||||
StructuredIndexed O);
|
||||
@@ -202,7 +197,7 @@ using MatmulRegionBuilder = function_ref<void(ArrayRef<BlockArgument> args)>;
|
||||
/// | C(m, n) += A(m, k) * B(k, n)
|
||||
/// ```
|
||||
Operation *
|
||||
linalg_generic_matmul(ValueHandle vA, ValueHandle vB, ValueHandle vC,
|
||||
linalg_generic_matmul(Value vA, Value vB, Value vC,
|
||||
MatmulRegionBuilder regionBuilder = macRegionBuilder);
|
||||
|
||||
/// Build a linalg.generic, under the current ScopedContext, at the current
|
||||
@@ -214,7 +209,7 @@ linalg_generic_matmul(ValueHandle vA, ValueHandle vB, ValueHandle vC,
|
||||
/// ```
|
||||
/// and returns the tensor `C`.
|
||||
Operation *
|
||||
linalg_generic_matmul(ValueHandle vA, ValueHandle vB, RankedTensorType tC,
|
||||
linalg_generic_matmul(Value vA, Value vB, RankedTensorType tC,
|
||||
MatmulRegionBuilder regionBuilder = mulRegionBuilder);
|
||||
|
||||
/// Build a linalg.generic, under the current ScopedContext, at the current
|
||||
@@ -226,8 +221,7 @@ linalg_generic_matmul(ValueHandle vA, ValueHandle vB, RankedTensorType tC,
|
||||
/// ```
|
||||
/// and returns the tensor `D`.
|
||||
Operation *
|
||||
linalg_generic_matmul(ValueHandle vA, ValueHandle vB, ValueHandle vC,
|
||||
RankedTensorType tD,
|
||||
linalg_generic_matmul(Value vA, Value vB, Value vC, RankedTensorType tD,
|
||||
MatmulRegionBuilder regionBuilder = macRegionBuilder);
|
||||
|
||||
template <typename Container>
|
||||
@@ -260,8 +254,8 @@ linalg_generic_matmul(Container values,
|
||||
/// For now `...` must be empty (i.e. only 2-D convolutions are supported).
|
||||
///
|
||||
// TODO(ntv) Extend convolution rank with some template magic.
|
||||
Operation *linalg_generic_conv_nhwc(ValueHandle vI, ValueHandle vW,
|
||||
ValueHandle vO, ArrayRef<int> strides = {},
|
||||
Operation *linalg_generic_conv_nhwc(Value vI, Value vW, Value vO,
|
||||
ArrayRef<int> strides = {},
|
||||
ArrayRef<int> dilations = {});
|
||||
|
||||
template <typename Container>
|
||||
@@ -295,8 +289,7 @@ Operation *linalg_generic_conv_nhwc(Container values,
|
||||
/// For now `...` must be empty (i.e. only 2-D convolutions are supported).
|
||||
///
|
||||
// TODO(ntv) Extend convolution rank with some template magic.
|
||||
Operation *linalg_generic_dilated_conv_nhwc(ValueHandle vI, ValueHandle vW,
|
||||
ValueHandle vO,
|
||||
Operation *linalg_generic_dilated_conv_nhwc(Value vI, Value vW, Value vO,
|
||||
int depth_multiplier = 1,
|
||||
ArrayRef<int> strides = {},
|
||||
ArrayRef<int> dilations = {});
|
||||
|
||||
@@ -15,16 +15,52 @@
|
||||
|
||||
namespace mlir {
|
||||
namespace edsc {
|
||||
namespace intrinsics {
|
||||
|
||||
template <typename Op, typename... Args>
|
||||
ValueHandle ValueHandle::create(OperationFolder *folder, Args... args) {
|
||||
return folder ? ValueHandle(folder->create<Op>(ScopedContext::getBuilder(),
|
||||
ScopedContext::getLocation(),
|
||||
args...))
|
||||
: ValueHandle(ScopedContext::getBuilder().create<Op>(
|
||||
ScopedContext::getLocation(), args...));
|
||||
}
|
||||
template <typename Op>
|
||||
struct FoldedValueBuilder {
|
||||
// Builder-based
|
||||
template <typename... Args>
|
||||
FoldedValueBuilder(OperationFolder *folder, Args... args) {
|
||||
value = folder ? folder->create<Op>(ScopedContext::getBuilder(),
|
||||
ScopedContext::getLocation(), args...)
|
||||
: ScopedContext::getBuilder().create<Op>(
|
||||
ScopedContext::getLocation(), args...);
|
||||
}
|
||||
|
||||
operator Value() { return value; }
|
||||
Value value;
|
||||
};
|
||||
|
||||
using folded_std_constant_index = FoldedValueBuilder<ConstantIndexOp>;
|
||||
using folded_std_constant_float = FoldedValueBuilder<ConstantFloatOp>;
|
||||
using folded_std_constant_int = FoldedValueBuilder<ConstantIntOp>;
|
||||
using folded_std_constant = FoldedValueBuilder<ConstantOp>;
|
||||
using folded_std_dim = FoldedValueBuilder<DimOp>;
|
||||
using folded_std_muli = FoldedValueBuilder<MulIOp>;
|
||||
using folded_std_addi = FoldedValueBuilder<AddIOp>;
|
||||
using folded_std_addf = FoldedValueBuilder<AddFOp>;
|
||||
using folded_std_alloc = FoldedValueBuilder<AllocOp>;
|
||||
using folded_std_constant = FoldedValueBuilder<ConstantOp>;
|
||||
using folded_std_constant_float = FoldedValueBuilder<ConstantFloatOp>;
|
||||
using folded_std_constant_index = FoldedValueBuilder<ConstantIndexOp>;
|
||||
using folded_std_constant_int = FoldedValueBuilder<ConstantIntOp>;
|
||||
using folded_std_dim = FoldedValueBuilder<DimOp>;
|
||||
using folded_std_extract_element = FoldedValueBuilder<ExtractElementOp>;
|
||||
using folded_std_index_cast = FoldedValueBuilder<IndexCastOp>;
|
||||
using folded_std_muli = FoldedValueBuilder<MulIOp>;
|
||||
using folded_std_mulf = FoldedValueBuilder<MulFOp>;
|
||||
using folded_std_memref_cast = FoldedValueBuilder<MemRefCastOp>;
|
||||
using folded_std_select = FoldedValueBuilder<SelectOp>;
|
||||
using folded_std_load = FoldedValueBuilder<LoadOp>;
|
||||
using folded_std_subi = FoldedValueBuilder<SubIOp>;
|
||||
using folded_std_sub_view = FoldedValueBuilder<SubViewOp>;
|
||||
using folded_std_tanh = FoldedValueBuilder<TanhOp>;
|
||||
using folded_std_tensor_load = FoldedValueBuilder<TensorLoadOp>;
|
||||
using folded_std_view = FoldedValueBuilder<ViewOp>;
|
||||
using folded_std_zero_extendi = FoldedValueBuilder<ZeroExtendIOp>;
|
||||
using folded_std_sign_extendi = FoldedValueBuilder<SignExtendIOp>;
|
||||
} // namespace intrinsics
|
||||
} // namespace edsc
|
||||
} // namespace mlir
|
||||
|
||||
|
||||
@@ -23,27 +23,30 @@ namespace mlir {
|
||||
namespace edsc {
|
||||
|
||||
/// Constructs a new loop::ParallelOp and captures the associated induction
|
||||
/// variables. An array of ValueHandle pointers is passed as the first
|
||||
/// variables. An array of Value pointers is passed as the first
|
||||
/// argument and is the *only* way to capture loop induction variables.
|
||||
LoopBuilder makeParallelLoopBuilder(ArrayRef<ValueHandle *> ivs,
|
||||
ArrayRef<ValueHandle> lbHandles,
|
||||
ArrayRef<ValueHandle> ubHandles,
|
||||
ArrayRef<ValueHandle> steps);
|
||||
LoopBuilder makeParallelLoopBuilder(MutableArrayRef<Value> ivs,
|
||||
ArrayRef<Value> lbs, ArrayRef<Value> ubs,
|
||||
ArrayRef<Value> steps);
|
||||
/// Constructs a new loop::ForOp and captures the associated induction
|
||||
/// variable. A ValueHandle pointer is passed as the first argument and is the
|
||||
/// variable. A Value pointer is passed as the first argument and is the
|
||||
/// *only* way to capture the loop induction variable.
|
||||
LoopBuilder makeLoopBuilder(ValueHandle *iv, ValueHandle lbHandle,
|
||||
ValueHandle ubHandle, ValueHandle stepHandle,
|
||||
ArrayRef<ValueHandle *> iter_args_handles = {},
|
||||
ValueRange iter_args_init_values = {});
|
||||
LoopBuilder makeLoopBuilder(Value *iv, Value lb, Value ub, Value step,
|
||||
MutableArrayRef<Value> iterArgsHandles,
|
||||
ValueRange iterArgsInitValues);
|
||||
LoopBuilder makeLoopBuilder(Value *iv, Value lb, Value ub, Value step,
|
||||
MutableArrayRef<Value> iterArgsHandles,
|
||||
ValueRange iterArgsInitValues);
|
||||
inline LoopBuilder makeLoopBuilder(Value *iv, Value lb, Value ub, Value step) {
|
||||
return makeLoopBuilder(iv, lb, ub, step, MutableArrayRef<Value>{}, {});
|
||||
}
|
||||
|
||||
/// Helper class to sugar building loop.parallel loop nests from lower/upper
|
||||
/// bounds and step sizes.
|
||||
class ParallelLoopNestBuilder {
|
||||
public:
|
||||
ParallelLoopNestBuilder(ArrayRef<ValueHandle *> ivs,
|
||||
ArrayRef<ValueHandle> lbs, ArrayRef<ValueHandle> ubs,
|
||||
ArrayRef<ValueHandle> steps);
|
||||
ParallelLoopNestBuilder(MutableArrayRef<Value> ivs, ArrayRef<Value> lbs,
|
||||
ArrayRef<Value> ubs, ArrayRef<Value> steps);
|
||||
|
||||
void operator()(function_ref<void(void)> fun = nullptr);
|
||||
|
||||
@@ -56,12 +59,12 @@ private:
|
||||
/// loop.for.
|
||||
class LoopNestBuilder {
|
||||
public:
|
||||
LoopNestBuilder(ValueHandle *iv, ValueHandle lb, ValueHandle ub,
|
||||
ValueHandle step,
|
||||
ArrayRef<ValueHandle *> iter_args_handles = {},
|
||||
ValueRange iter_args_init_values = {});
|
||||
LoopNestBuilder(ArrayRef<ValueHandle *> ivs, ArrayRef<ValueHandle> lbs,
|
||||
ArrayRef<ValueHandle> ubs, ArrayRef<ValueHandle> steps);
|
||||
LoopNestBuilder(Value *iv, Value lb, Value ub, Value step);
|
||||
LoopNestBuilder(Value *iv, Value lb, Value ub, Value step,
|
||||
MutableArrayRef<Value> iterArgsHandles,
|
||||
ValueRange iterArgsInitValues);
|
||||
LoopNestBuilder(MutableArrayRef<Value> ivs, ArrayRef<Value> lbs,
|
||||
ArrayRef<Value> ubs, ArrayRef<Value> steps);
|
||||
Operation::result_range operator()(std::function<void(void)> fun = nullptr);
|
||||
|
||||
private:
|
||||
|
||||
@@ -20,27 +20,27 @@ namespace edsc {
|
||||
class BoundsCapture {
|
||||
public:
|
||||
unsigned rank() const { return lbs.size(); }
|
||||
ValueHandle lb(unsigned idx) { return lbs[idx]; }
|
||||
ValueHandle ub(unsigned idx) { return ubs[idx]; }
|
||||
Value lb(unsigned idx) { return lbs[idx]; }
|
||||
Value ub(unsigned idx) { return ubs[idx]; }
|
||||
int64_t step(unsigned idx) { return steps[idx]; }
|
||||
std::tuple<ValueHandle, ValueHandle, int64_t> range(unsigned idx) {
|
||||
std::tuple<Value, Value, int64_t> range(unsigned idx) {
|
||||
return std::make_tuple(lbs[idx], ubs[idx], steps[idx]);
|
||||
}
|
||||
void swapRanges(unsigned i, unsigned j) {
|
||||
if (i == j)
|
||||
return;
|
||||
lbs[i].swap(lbs[j]);
|
||||
ubs[i].swap(ubs[j]);
|
||||
std::swap(lbs[i], lbs[j]);
|
||||
std::swap(ubs[i], ubs[j]);
|
||||
std::swap(steps[i], steps[j]);
|
||||
}
|
||||
|
||||
ArrayRef<ValueHandle> getLbs() { return lbs; }
|
||||
ArrayRef<ValueHandle> getUbs() { return ubs; }
|
||||
ArrayRef<Value> getLbs() { return lbs; }
|
||||
ArrayRef<Value> getUbs() { return ubs; }
|
||||
ArrayRef<int64_t> getSteps() { return steps; }
|
||||
|
||||
protected:
|
||||
SmallVector<ValueHandle, 8> lbs;
|
||||
SmallVector<ValueHandle, 8> ubs;
|
||||
SmallVector<Value, 8> lbs;
|
||||
SmallVector<Value, 8> ubs;
|
||||
SmallVector<int64_t, 8> steps;
|
||||
};
|
||||
|
||||
@@ -58,7 +58,7 @@ public:
|
||||
unsigned fastestVarying() const { return rank() - 1; }
|
||||
|
||||
private:
|
||||
ValueHandle base;
|
||||
Value base;
|
||||
};
|
||||
|
||||
/// A VectorBoundsCapture represents the information required to step through a
|
||||
@@ -72,7 +72,7 @@ public:
|
||||
VectorBoundsCapture &operator=(const VectorBoundsCapture &) = default;
|
||||
|
||||
private:
|
||||
ValueHandle base;
|
||||
Value base;
|
||||
};
|
||||
|
||||
} // namespace edsc
|
||||
|
||||
@@ -14,40 +14,6 @@
|
||||
namespace mlir {
|
||||
namespace edsc {
|
||||
namespace intrinsics {
|
||||
namespace folded {
|
||||
/// Helper variadic abstraction to allow extending to any MLIR op without
|
||||
/// boilerplate or Tablegen.
|
||||
/// Arguably a builder is not a ValueHandle but in practice it is only used as
|
||||
/// an alias to a notional ValueHandle<Op>.
|
||||
/// Implementing it as a subclass allows it to compose all the way to Value.
|
||||
/// Without subclassing, implicit conversion to Value would fail when composing
|
||||
/// in patterns such as: `select(a, b, select(c, d, e))`.
|
||||
template <typename Op>
|
||||
struct ValueBuilder : public ValueHandle {
|
||||
/// Folder-based
|
||||
template <typename... Args>
|
||||
ValueBuilder(OperationFolder *folder, Args... args)
|
||||
: ValueHandle(ValueHandle::create<Op>(folder, detail::unpack(args)...)) {}
|
||||
ValueBuilder(OperationFolder *folder, ArrayRef<ValueHandle> vs)
|
||||
: ValueBuilder(ValueBuilder::create<Op>(folder, detail::unpack(vs))) {}
|
||||
template <typename... Args>
|
||||
ValueBuilder(OperationFolder *folder, ArrayRef<ValueHandle> vs, Args... args)
|
||||
: ValueHandle(ValueHandle::create<Op>(folder, detail::unpack(vs),
|
||||
detail::unpack(args)...)) {}
|
||||
template <typename T, typename... Args>
|
||||
ValueBuilder(OperationFolder *folder, T t, ArrayRef<ValueHandle> vs,
|
||||
Args... args)
|
||||
: ValueHandle(ValueHandle::create<Op>(folder, detail::unpack(t),
|
||||
detail::unpack(vs),
|
||||
detail::unpack(args)...)) {}
|
||||
template <typename T1, typename T2, typename... Args>
|
||||
ValueBuilder(OperationFolder *folder, T1 t1, T2 t2, ArrayRef<ValueHandle> vs,
|
||||
Args... args)
|
||||
: ValueHandle(ValueHandle::create<Op>(
|
||||
folder, detail::unpack(t1), detail::unpack(t2), detail::unpack(vs),
|
||||
detail::unpack(args)...)) {}
|
||||
};
|
||||
} // namespace folded
|
||||
|
||||
using std_addf = ValueBuilder<AddFOp>;
|
||||
using std_alloc = ValueBuilder<AllocOp>;
|
||||
@@ -80,7 +46,7 @@ using std_sign_extendi = ValueBuilder<SignExtendIOp>;
|
||||
///
|
||||
/// Prerequisites:
|
||||
/// All Handles have already captured previously constructed IR objects.
|
||||
OperationHandle std_br(BlockHandle bh, ArrayRef<ValueHandle> operands);
|
||||
OperationHandle std_br(BlockHandle bh, ArrayRef<Value> operands);
|
||||
|
||||
/// Creates a new mlir::Block* and branches to it from the current block.
|
||||
/// Argument types are specified by `operands`.
|
||||
@@ -95,8 +61,9 @@ OperationHandle std_br(BlockHandle bh, ArrayRef<ValueHandle> operands);
|
||||
/// All `operands` have already captured an mlir::Value
|
||||
/// captures.size() == operands.size()
|
||||
/// captures and operands are pairwise of the same type.
|
||||
OperationHandle std_br(BlockHandle *bh, ArrayRef<ValueHandle *> captures,
|
||||
ArrayRef<ValueHandle> operands);
|
||||
OperationHandle std_br(BlockHandle *bh, ArrayRef<Type> types,
|
||||
MutableArrayRef<Value> captures,
|
||||
ArrayRef<Value> operands);
|
||||
|
||||
/// Branches into the mlir::Block* captured by BlockHandle `trueBranch` with
|
||||
/// `trueOperands` if `cond` evaluates to `true` (resp. `falseBranch` and
|
||||
@@ -104,10 +71,10 @@ OperationHandle std_br(BlockHandle *bh, ArrayRef<ValueHandle *> captures,
|
||||
///
|
||||
/// Prerequisites:
|
||||
/// All Handles have captured previously constructed IR objects.
|
||||
OperationHandle std_cond_br(ValueHandle cond, BlockHandle trueBranch,
|
||||
ArrayRef<ValueHandle> trueOperands,
|
||||
OperationHandle std_cond_br(Value cond, BlockHandle trueBranch,
|
||||
ArrayRef<Value> trueOperands,
|
||||
BlockHandle falseBranch,
|
||||
ArrayRef<ValueHandle> falseOperands);
|
||||
ArrayRef<Value> falseOperands);
|
||||
|
||||
/// Eagerly creates new mlir::Block* with argument types specified by
|
||||
/// `trueOperands`/`falseOperands`.
|
||||
@@ -125,45 +92,17 @@ OperationHandle std_cond_br(ValueHandle cond, BlockHandle trueBranch,
|
||||
/// `falseCaptures`.size() == `falseOperands`.size()
|
||||
/// `trueCaptures` and `trueOperands` are pairwise of the same type
|
||||
/// `falseCaptures` and `falseOperands` are pairwise of the same type.
|
||||
OperationHandle std_cond_br(ValueHandle cond, BlockHandle *trueBranch,
|
||||
ArrayRef<ValueHandle *> trueCaptures,
|
||||
ArrayRef<ValueHandle> trueOperands,
|
||||
BlockHandle *falseBranch,
|
||||
ArrayRef<ValueHandle *> falseCaptures,
|
||||
ArrayRef<ValueHandle> falseOperands);
|
||||
OperationHandle std_cond_br(Value cond, BlockHandle *trueBranch,
|
||||
ArrayRef<Type> trueTypes,
|
||||
MutableArrayRef<Value> trueCaptures,
|
||||
ArrayRef<Value> trueOperands,
|
||||
BlockHandle *falseBranch, ArrayRef<Type> falseTypes,
|
||||
MutableArrayRef<Value> falseCaptures,
|
||||
ArrayRef<Value> falseOperands);
|
||||
|
||||
/// Provide an index notation around std_load and std_store.
|
||||
using StdIndexedValue =
|
||||
TemplatedIndexedValue<intrinsics::std_load, intrinsics::std_store>;
|
||||
|
||||
using folded_std_constant_index = folded::ValueBuilder<ConstantIndexOp>;
|
||||
using folded_std_constant_float = folded::ValueBuilder<ConstantFloatOp>;
|
||||
using folded_std_constant_int = folded::ValueBuilder<ConstantIntOp>;
|
||||
using folded_std_constant = folded::ValueBuilder<ConstantOp>;
|
||||
using folded_std_dim = folded::ValueBuilder<DimOp>;
|
||||
using folded_std_muli = folded::ValueBuilder<MulIOp>;
|
||||
using folded_std_addi = folded::ValueBuilder<AddIOp>;
|
||||
using folded_std_addf = folded::ValueBuilder<AddFOp>;
|
||||
using folded_std_alloc = folded::ValueBuilder<AllocOp>;
|
||||
using folded_std_constant = folded::ValueBuilder<ConstantOp>;
|
||||
using folded_std_constant_float = folded::ValueBuilder<ConstantFloatOp>;
|
||||
using folded_std_constant_index = folded::ValueBuilder<ConstantIndexOp>;
|
||||
using folded_std_constant_int = folded::ValueBuilder<ConstantIntOp>;
|
||||
using folded_std_dim = folded::ValueBuilder<DimOp>;
|
||||
using folded_std_extract_element = folded::ValueBuilder<ExtractElementOp>;
|
||||
using folded_std_index_cast = folded::ValueBuilder<IndexCastOp>;
|
||||
using folded_std_muli = folded::ValueBuilder<MulIOp>;
|
||||
using folded_std_mulf = folded::ValueBuilder<MulFOp>;
|
||||
using folded_std_memref_cast = folded::ValueBuilder<MemRefCastOp>;
|
||||
using folded_std_select = folded::ValueBuilder<SelectOp>;
|
||||
using folded_std_load = folded::ValueBuilder<LoadOp>;
|
||||
using folded_std_subi = folded::ValueBuilder<SubIOp>;
|
||||
using folded_std_sub_view = folded::ValueBuilder<SubViewOp>;
|
||||
using folded_std_tanh = folded::ValueBuilder<TanhOp>;
|
||||
using folded_std_tensor_load = folded::ValueBuilder<TensorLoadOp>;
|
||||
using folded_std_view = folded::ValueBuilder<ViewOp>;
|
||||
using folded_std_zero_extendi = folded::ValueBuilder<ZeroExtendIOp>;
|
||||
using folded_std_sign_extendi = folded::ValueBuilder<SignExtendIOp>;
|
||||
} // namespace intrinsics
|
||||
} // namespace edsc
|
||||
} // namespace mlir
|
||||
|
||||
@@ -18,6 +18,7 @@ using vector_broadcast = ValueBuilder<vector::BroadcastOp>;
|
||||
using vector_contract = ValueBuilder<vector::ContractionOp>;
|
||||
using vector_matmul = ValueBuilder<vector::MatmulOp>;
|
||||
using vector_print = OperationBuilder<vector::PrintOp>;
|
||||
using vector_type_cast = ValueBuilder<vector::TypeCastOp>;
|
||||
|
||||
} // namespace intrinsics
|
||||
} // namespace edsc
|
||||
|
||||
@@ -24,9 +24,7 @@ class OperationFolder;
|
||||
|
||||
namespace edsc {
|
||||
class BlockHandle;
|
||||
class CapturableHandle;
|
||||
class NestedBuilder;
|
||||
class ValueHandle;
|
||||
|
||||
/// Helper class to transparently handle builder insertion points by RAII.
|
||||
/// As its name indicates, a ScopedContext is meant to be used locally in a
|
||||
@@ -70,10 +68,23 @@ private:
|
||||
/// Defensively keeps track of the current NestedBuilder to ensure proper
|
||||
/// scoping usage.
|
||||
NestedBuilder *nestedBuilder;
|
||||
};
|
||||
|
||||
// TODO: Implement scoping of ValueHandles. To do this we need a proper data
|
||||
// structure to hold ValueHandle objects. We can emulate one but there should
|
||||
// already be something available in LLVM for this purpose.
|
||||
template <typename Op>
|
||||
struct ValueBuilder {
|
||||
// Builder-based
|
||||
template <typename... Args>
|
||||
ValueBuilder(Args... args) {
|
||||
Operation *op = ScopedContext::getBuilder()
|
||||
.create<Op>(ScopedContext::getLocation(), args...)
|
||||
.getOperation();
|
||||
if (op->getNumResults() != 1)
|
||||
llvm_unreachable("unsupported operation, use OperationBuilder instead");
|
||||
value = op->getResult(0);
|
||||
}
|
||||
|
||||
operator Value() { return value; }
|
||||
Value value;
|
||||
};
|
||||
|
||||
/// A NestedBuilder is a scoping abstraction to create an idiomatic syntax
|
||||
@@ -82,8 +93,7 @@ private:
|
||||
/// exists between object construction and method invocation on said object (in
|
||||
/// our case, the call to `operator()`).
|
||||
/// This ordering allows implementing an abstraction that decouples definition
|
||||
/// from declaration (in a PL sense) on placeholders of type ValueHandle and
|
||||
/// BlockHandle.
|
||||
/// from declaration (in a PL sense) on placeholders.
|
||||
class NestedBuilder {
|
||||
protected:
|
||||
NestedBuilder() = default;
|
||||
@@ -158,19 +168,17 @@ public:
|
||||
private:
|
||||
LoopBuilder() = default;
|
||||
|
||||
friend LoopBuilder makeAffineLoopBuilder(ValueHandle *iv,
|
||||
ArrayRef<ValueHandle> lbHandles,
|
||||
ArrayRef<ValueHandle> ubHandles,
|
||||
friend LoopBuilder makeAffineLoopBuilder(Value *iv, ArrayRef<Value> lbHandles,
|
||||
ArrayRef<Value> ubHandles,
|
||||
int64_t step);
|
||||
friend LoopBuilder makeParallelLoopBuilder(ArrayRef<ValueHandle *> ivs,
|
||||
ArrayRef<ValueHandle> lbHandles,
|
||||
ArrayRef<ValueHandle> ubHandles,
|
||||
ArrayRef<ValueHandle> steps);
|
||||
friend LoopBuilder makeLoopBuilder(ValueHandle *iv, ValueHandle lbHandle,
|
||||
ValueHandle ubHandle,
|
||||
ValueHandle stepHandle,
|
||||
ArrayRef<ValueHandle *> iter_args_handles,
|
||||
ValueRange iter_args_init_values);
|
||||
friend LoopBuilder makeParallelLoopBuilder(MutableArrayRef<Value> ivs,
|
||||
ArrayRef<Value> lbHandles,
|
||||
ArrayRef<Value> ubHandles,
|
||||
ArrayRef<Value> steps);
|
||||
friend LoopBuilder makeLoopBuilder(Value *iv, Value lbHandle, Value ubHandle,
|
||||
Value stepHandle,
|
||||
MutableArrayRef<Value> iterArgsHandles,
|
||||
ValueRange iterArgsInitValues);
|
||||
Operation *op;
|
||||
};
|
||||
|
||||
@@ -194,9 +202,11 @@ public:
|
||||
/// Enters the new mlir::Block* and sets the insertion point to its end.
|
||||
///
|
||||
/// Prerequisites:
|
||||
/// The ValueHandle `args` are typed delayed ValueHandles; i.e. they are
|
||||
/// The Value `args` are typed delayed Values; i.e. they are
|
||||
/// not yet bound to mlir::Value.
|
||||
BlockBuilder(BlockHandle *bh, ArrayRef<ValueHandle *> args);
|
||||
BlockBuilder(BlockHandle *bh) : BlockBuilder(bh, {}, {}) {}
|
||||
BlockBuilder(BlockHandle *bh, ArrayRef<Type> types,
|
||||
MutableArrayRef<Value> args);
|
||||
|
||||
/// Constructs a new mlir::Block with argument types derived from `args` and
|
||||
/// appends it as the last block in the region.
|
||||
@@ -204,9 +214,10 @@ public:
|
||||
/// Enters the new mlir::Block* and sets the insertion point to its end.
|
||||
///
|
||||
/// Prerequisites:
|
||||
/// The ValueHandle `args` are typed delayed ValueHandles; i.e. they are
|
||||
/// The Value `args` are typed delayed Values; i.e. they are
|
||||
/// not yet bound to mlir::Value.
|
||||
BlockBuilder(BlockHandle *bh, Region ®ion, ArrayRef<ValueHandle *> args);
|
||||
BlockBuilder(BlockHandle *bh, Region ®ion, ArrayRef<Type> types,
|
||||
MutableArrayRef<Value> args);
|
||||
|
||||
/// The only purpose of this operator is to serve as a sequence point so that
|
||||
/// the evaluation of `fun` (which builds IR snippets in a scoped fashion) is
|
||||
@@ -218,120 +229,18 @@ private:
|
||||
BlockBuilder &operator=(BlockBuilder &other) = delete;
|
||||
};
|
||||
|
||||
/// Base class for ValueHandle, OperationHandle and BlockHandle.
|
||||
/// Base class for Value, OperationHandle and BlockHandle.
|
||||
/// Not meant to be used outside of these classes.
|
||||
class CapturableHandle {
|
||||
protected:
|
||||
CapturableHandle() = default;
|
||||
};
|
||||
|
||||
/// ValueHandle implements a (potentially "delayed") typed Value abstraction.
|
||||
/// ValueHandle should be captured by pointer but otherwise passed by Value
|
||||
/// everywhere.
|
||||
/// A ValueHandle can have 3 states:
|
||||
/// 1. null state (empty type and empty value), in which case it does not hold
|
||||
/// a value and must never hold a Value (now or in the future). This is
|
||||
/// used for MLIR operations with zero returns as well as the result of
|
||||
/// calling a NestedBuilder::operator(). In both cases the objective is to
|
||||
/// have an object that can be inserted in an ArrayRef<ValueHandle> to
|
||||
/// implement nesting;
|
||||
/// 2. delayed state (empty value), in which case it represents an eagerly
|
||||
/// typed "delayed" value that can hold a Value in the future;
|
||||
/// 3. constructed state, in which case it holds a Value.
|
||||
///
|
||||
/// A ValueHandle is meant to capture a single Value and should be used for
|
||||
/// operations that have a single result. For convenience of use, we also
|
||||
/// include AffineForOp in this category although it does not return a value.
|
||||
/// In the case of AffineForOp, the captured Value is the loop induction
|
||||
/// variable.
|
||||
class ValueHandle : public CapturableHandle {
|
||||
public:
|
||||
/// A ValueHandle in a null state can never be captured;
|
||||
static ValueHandle null() { return ValueHandle(); }
|
||||
|
||||
/// A ValueHandle that is constructed from a Type represents a typed "delayed"
|
||||
/// Value. A delayed Value can only capture Values of the specified type.
|
||||
/// Such a delayed value represents the declaration (in the PL sense) of a
|
||||
/// placeholder for an mlir::Value that will be constructed and captured at
|
||||
/// some later point in the program.
|
||||
explicit ValueHandle(Type t) : t(t), v(nullptr) {}
|
||||
|
||||
/// A ValueHandle that is constructed from an mlir::Value is an "eager"
|
||||
/// Value. An eager Value represents both the declaration and the definition
|
||||
/// (in the PL sense) of a placeholder for an mlir::Value that has already
|
||||
/// been constructed in the past and that is captured "now" in the program.
|
||||
explicit ValueHandle(Value v) : t(v.getType()), v(v) {}
|
||||
|
||||
/// ValueHandle is a value type, use the default copy constructor.
|
||||
ValueHandle(const ValueHandle &other) = default;
|
||||
|
||||
/// ValueHandle is a value type, the assignment operator typechecks before
|
||||
/// assigning.
|
||||
ValueHandle &operator=(const ValueHandle &other);
|
||||
|
||||
/// Provide a swap operator.
|
||||
void swap(ValueHandle &other) {
|
||||
if (this == &other)
|
||||
return;
|
||||
std::swap(t, other.t);
|
||||
std::swap(v, other.v);
|
||||
}
|
||||
|
||||
/// Implicit conversion useful for automatic conversion to Container<Value>.
|
||||
operator Value() const { return getValue(); }
|
||||
operator Type() const { return getType(); }
|
||||
operator bool() const { return hasValue(); }
|
||||
|
||||
/// Generic mlir::Op create. This is the key to being extensible to the whole
|
||||
/// of MLIR without duplicating the type system or the op definitions.
|
||||
template <typename Op, typename... Args>
|
||||
static ValueHandle create(Args... args);
|
||||
|
||||
/// Generic mlir::Op create. This is the key to being extensible to the whole
|
||||
/// of MLIR without duplicating the type system or the op definitions.
|
||||
/// When non-null, the optional pointer `folder` is used to call into the
|
||||
/// `createAndFold` builder method. If `folder` is null, the regular `create`
|
||||
/// method is called.
|
||||
template <typename Op, typename... Args>
|
||||
static ValueHandle create(OperationFolder *folder, Args... args);
|
||||
|
||||
/// Generic create for a named operation producing a single value.
|
||||
static ValueHandle create(StringRef name, ArrayRef<ValueHandle> operands,
|
||||
ArrayRef<Type> resultTypes,
|
||||
ArrayRef<NamedAttribute> attributes = {});
|
||||
|
||||
bool hasValue() const { return v != nullptr; }
|
||||
Value getValue() const {
|
||||
assert(hasValue() && "Unexpected null value;");
|
||||
return v;
|
||||
}
|
||||
bool hasType() const { return t != Type(); }
|
||||
Type getType() const { return t; }
|
||||
|
||||
Operation *getOperation() const {
|
||||
if (!v)
|
||||
return nullptr;
|
||||
return v.getDefiningOp();
|
||||
}
|
||||
|
||||
// Return a vector of fresh ValueHandles that have not captured.
|
||||
static SmallVector<ValueHandle, 8> makeIndexHandles(unsigned count) {
|
||||
auto indexType = IndexType::get(ScopedContext::getContext());
|
||||
return SmallVector<ValueHandle, 8>(count, ValueHandle(indexType));
|
||||
}
|
||||
|
||||
protected:
|
||||
ValueHandle() : t(), v(nullptr) {}
|
||||
|
||||
Type t;
|
||||
Value v;
|
||||
};
|
||||
|
||||
/// An OperationHandle can be used in lieu of ValueHandle to capture the
|
||||
/// An OperationHandle can be used in lieu of Value to capture the
|
||||
/// operation in cases when one does not care about, or cannot extract, a
|
||||
/// unique Value from the operation.
|
||||
/// This can be used for capturing zero result operations as well as
|
||||
/// multi-result operations that are not supported by ValueHandle.
|
||||
/// multi-result operations that are not supported by Value.
|
||||
/// We do not distinguish further between zero and multi-result operations at
|
||||
/// this time.
|
||||
struct OperationHandle : public CapturableHandle {
|
||||
@@ -349,7 +258,7 @@ struct OperationHandle : public CapturableHandle {
|
||||
static Op createOp(Args... args);
|
||||
|
||||
/// Generic create for a named operation.
|
||||
static OperationHandle create(StringRef name, ArrayRef<ValueHandle> operands,
|
||||
static OperationHandle create(StringRef name, ArrayRef<Value> operands,
|
||||
ArrayRef<Type> resultTypes,
|
||||
ArrayRef<NamedAttribute> attributes = {});
|
||||
|
||||
@@ -360,23 +269,6 @@ private:
|
||||
Operation *op;
|
||||
};
|
||||
|
||||
/// Simple wrapper to build a generic operation without successor blocks.
|
||||
template <typename HandleType>
|
||||
struct CustomOperation {
|
||||
CustomOperation(StringRef name) : name(name) {
|
||||
static_assert(std::is_same<HandleType, ValueHandle>() ||
|
||||
std::is_same<HandleType, OperationHandle>(),
|
||||
"Only CustomOperation<ValueHandle> or "
|
||||
"CustomOperation<OperationHandle> can be constructed.");
|
||||
}
|
||||
HandleType operator()(ArrayRef<ValueHandle> operands = {},
|
||||
ArrayRef<Type> resultTypes = {},
|
||||
ArrayRef<NamedAttribute> attributes = {}) {
|
||||
return HandleType::create(name, operands, resultTypes, attributes);
|
||||
}
|
||||
std::string name;
|
||||
};
|
||||
|
||||
/// A BlockHandle represents a (potentially "delayed") Block abstraction.
|
||||
/// This extra abstraction is necessary because an mlir::Block is not an
|
||||
/// mlir::Value.
|
||||
@@ -427,32 +319,45 @@ private:
|
||||
/// C(buffer_value_or_tensor_type);
|
||||
/// makeGenericLinalgOp({A({m, n}), B({k, n})}, {C({m, n})}, ... );
|
||||
/// ```
|
||||
struct StructuredIndexed : public ValueHandle {
|
||||
StructuredIndexed(Type type) : ValueHandle(type) {}
|
||||
StructuredIndexed(Value value) : ValueHandle(value) {}
|
||||
StructuredIndexed(ValueHandle valueHandle) : ValueHandle(valueHandle) {}
|
||||
struct StructuredIndexed {
|
||||
StructuredIndexed(Value v) : value(v) {}
|
||||
StructuredIndexed(Type t) : type(t) {}
|
||||
StructuredIndexed operator()(ArrayRef<AffineExpr> indexings) {
|
||||
return this->hasValue() ? StructuredIndexed(this->getValue(), indexings)
|
||||
: StructuredIndexed(this->getType(), indexings);
|
||||
return value ? StructuredIndexed(value, indexings)
|
||||
: StructuredIndexed(type, indexings);
|
||||
}
|
||||
|
||||
StructuredIndexed(Type t, ArrayRef<AffineExpr> indexings)
|
||||
: ValueHandle(t), exprs(indexings.begin(), indexings.end()) {
|
||||
assert(t.isa<RankedTensorType>() && "RankedTensor expected");
|
||||
}
|
||||
StructuredIndexed(Value v, ArrayRef<AffineExpr> indexings)
|
||||
: ValueHandle(v), exprs(indexings.begin(), indexings.end()) {
|
||||
: value(v), exprs(indexings.begin(), indexings.end()) {
|
||||
assert((v.getType().isa<MemRefType>() ||
|
||||
v.getType().isa<RankedTensorType>() ||
|
||||
v.getType().isa<VectorType>()) &&
|
||||
"MemRef, RankedTensor or Vector expected");
|
||||
}
|
||||
StructuredIndexed(ValueHandle vh, ArrayRef<AffineExpr> indexings)
|
||||
: ValueHandle(vh), exprs(indexings.begin(), indexings.end()) {}
|
||||
StructuredIndexed(Type t, ArrayRef<AffineExpr> indexings)
|
||||
: type(t), exprs(indexings.begin(), indexings.end()) {
|
||||
assert((t.isa<MemRefType>() || t.isa<RankedTensorType>() ||
|
||||
t.isa<VectorType>()) &&
|
||||
"MemRef, RankedTensor or Vector expected");
|
||||
}
|
||||
|
||||
ArrayRef<AffineExpr> getExprs() { return exprs; }
|
||||
bool hasValue() const { return value; }
|
||||
Value getValue() const {
|
||||
assert(value && "StructuredIndexed Value not set.");
|
||||
return value;
|
||||
}
|
||||
Type getType() const {
|
||||
assert((value || type) && "StructuredIndexed Value and Type not set.");
|
||||
return value ? value.getType() : type;
|
||||
}
|
||||
ArrayRef<AffineExpr> getExprs() const { return exprs; }
|
||||
operator Value() const { return getValue(); }
|
||||
operator Type() const { return getType(); }
|
||||
|
||||
private:
|
||||
// Only one of Value or type may be set.
|
||||
Type type;
|
||||
Value value;
|
||||
SmallVector<AffineExpr, 4> exprs;
|
||||
};
|
||||
|
||||
@@ -472,179 +377,139 @@ Op OperationHandle::createOp(Args... args) {
|
||||
.getOperation());
|
||||
}
|
||||
|
||||
template <typename Op, typename... Args>
|
||||
ValueHandle ValueHandle::create(Args... args) {
|
||||
Operation *op = ScopedContext::getBuilder()
|
||||
.create<Op>(ScopedContext::getLocation(), args...)
|
||||
.getOperation();
|
||||
if (op->getNumResults() == 1)
|
||||
return ValueHandle(op->getResult(0));
|
||||
llvm_unreachable("unsupported operation, use an OperationHandle instead");
|
||||
}
|
||||
|
||||
/// Entry point to build multiple ValueHandle from a `Container` of Value or
|
||||
/// Type.
|
||||
template <typename Container>
|
||||
inline SmallVector<ValueHandle, 8> makeValueHandles(Container values) {
|
||||
SmallVector<ValueHandle, 8> res;
|
||||
res.reserve(values.size());
|
||||
for (auto v : values)
|
||||
res.push_back(ValueHandle(v));
|
||||
return res;
|
||||
}
|
||||
|
||||
/// A TemplatedIndexedValue brings an index notation over the template Load and
|
||||
/// Store parameters. Assigning to an IndexedValue emits an actual `Store`
|
||||
/// operation, while converting an IndexedValue to a ValueHandle emits an actual
|
||||
/// operation, while converting an IndexedValue to a Value emits an actual
|
||||
/// `Load` operation.
|
||||
template <typename Load, typename Store>
|
||||
class TemplatedIndexedValue {
|
||||
public:
|
||||
explicit TemplatedIndexedValue(Type t) : base(t) {}
|
||||
explicit TemplatedIndexedValue(Value v)
|
||||
: TemplatedIndexedValue(ValueHandle(v)) {}
|
||||
explicit TemplatedIndexedValue(ValueHandle v) : base(v) {}
|
||||
explicit TemplatedIndexedValue(Value v) : value(v) {}
|
||||
|
||||
TemplatedIndexedValue(const TemplatedIndexedValue &rhs) = default;
|
||||
|
||||
TemplatedIndexedValue operator()() { return *this; }
|
||||
/// Returns a new `TemplatedIndexedValue`.
|
||||
TemplatedIndexedValue operator()(ValueHandle index) {
|
||||
TemplatedIndexedValue res(base);
|
||||
TemplatedIndexedValue operator()(Value index) {
|
||||
TemplatedIndexedValue res(value);
|
||||
res.indices.push_back(index);
|
||||
return res;
|
||||
}
|
||||
template <typename... Args>
|
||||
TemplatedIndexedValue operator()(ValueHandle index, Args... indices) {
|
||||
return TemplatedIndexedValue(base, index).append(indices...);
|
||||
TemplatedIndexedValue operator()(Value index, Args... indices) {
|
||||
return TemplatedIndexedValue(value, index).append(indices...);
|
||||
}
|
||||
TemplatedIndexedValue operator()(ArrayRef<ValueHandle> indices) {
|
||||
return TemplatedIndexedValue(base, indices);
|
||||
TemplatedIndexedValue operator()(ArrayRef<Value> indices) {
|
||||
return TemplatedIndexedValue(value, indices);
|
||||
}
|
||||
|
||||
/// Emits a `store`.
|
||||
OperationHandle operator=(const TemplatedIndexedValue &rhs) {
|
||||
ValueHandle rrhs(rhs);
|
||||
return Store(rrhs, getBase(), {indices.begin(), indices.end()});
|
||||
}
|
||||
OperationHandle operator=(ValueHandle rhs) {
|
||||
return Store(rhs, getBase(), {indices.begin(), indices.end()});
|
||||
}
|
||||
|
||||
/// Emits a `load` when converting to a ValueHandle.
|
||||
operator ValueHandle() const {
|
||||
return Load(getBase(), {indices.begin(), indices.end()});
|
||||
return Store(rhs, value, indices);
|
||||
}
|
||||
OperationHandle operator=(Value rhs) { return Store(rhs, value, indices); }
|
||||
|
||||
/// Emits a `load` when converting to a Value.
|
||||
Value operator*(void)const {
|
||||
return Load(getBase(), {indices.begin(), indices.end()}).getValue();
|
||||
}
|
||||
operator Value() const { return Load(value, indices); }
|
||||
|
||||
ValueHandle getBase() const { return base; }
|
||||
Value getBase() const { return value; }
|
||||
|
||||
/// Arithmetic operator overloadings.
|
||||
ValueHandle operator+(ValueHandle e);
|
||||
ValueHandle operator-(ValueHandle e);
|
||||
ValueHandle operator*(ValueHandle e);
|
||||
ValueHandle operator/(ValueHandle e);
|
||||
ValueHandle operator%(ValueHandle e);
|
||||
ValueHandle operator^(ValueHandle e);
|
||||
ValueHandle operator+(TemplatedIndexedValue e) {
|
||||
return *this + static_cast<ValueHandle>(e);
|
||||
Value operator+(Value e);
|
||||
Value operator-(Value e);
|
||||
Value operator*(Value e);
|
||||
Value operator/(Value e);
|
||||
Value operator%(Value e);
|
||||
Value operator^(Value e);
|
||||
Value operator+(TemplatedIndexedValue e) {
|
||||
return *this + static_cast<Value>(e);
|
||||
}
|
||||
ValueHandle operator-(TemplatedIndexedValue e) {
|
||||
return *this - static_cast<ValueHandle>(e);
|
||||
Value operator-(TemplatedIndexedValue e) {
|
||||
return *this - static_cast<Value>(e);
|
||||
}
|
||||
ValueHandle operator*(TemplatedIndexedValue e) {
|
||||
return *this * static_cast<ValueHandle>(e);
|
||||
Value operator*(TemplatedIndexedValue e) {
|
||||
return *this * static_cast<Value>(e);
|
||||
}
|
||||
ValueHandle operator/(TemplatedIndexedValue e) {
|
||||
return *this / static_cast<ValueHandle>(e);
|
||||
Value operator/(TemplatedIndexedValue e) {
|
||||
return *this / static_cast<Value>(e);
|
||||
}
|
||||
ValueHandle operator%(TemplatedIndexedValue e) {
|
||||
return *this % static_cast<ValueHandle>(e);
|
||||
Value operator%(TemplatedIndexedValue e) {
|
||||
return *this % static_cast<Value>(e);
|
||||
}
|
||||
ValueHandle operator^(TemplatedIndexedValue e) {
|
||||
return *this ^ static_cast<ValueHandle>(e);
|
||||
Value operator^(TemplatedIndexedValue e) {
|
||||
return *this ^ static_cast<Value>(e);
|
||||
}
|
||||
|
||||
/// Assignment-arithmetic operator overloadings.
|
||||
OperationHandle operator+=(ValueHandle e);
|
||||
OperationHandle operator-=(ValueHandle e);
|
||||
OperationHandle operator*=(ValueHandle e);
|
||||
OperationHandle operator/=(ValueHandle e);
|
||||
OperationHandle operator%=(ValueHandle e);
|
||||
OperationHandle operator^=(ValueHandle e);
|
||||
OperationHandle operator+=(Value e);
|
||||
OperationHandle operator-=(Value e);
|
||||
OperationHandle operator*=(Value e);
|
||||
OperationHandle operator/=(Value e);
|
||||
OperationHandle operator%=(Value e);
|
||||
OperationHandle operator^=(Value e);
|
||||
OperationHandle operator+=(TemplatedIndexedValue e) {
|
||||
return this->operator+=(static_cast<ValueHandle>(e));
|
||||
return this->operator+=(static_cast<Value>(e));
|
||||
}
|
||||
OperationHandle operator-=(TemplatedIndexedValue e) {
|
||||
return this->operator-=(static_cast<ValueHandle>(e));
|
||||
return this->operator-=(static_cast<Value>(e));
|
||||
}
|
||||
OperationHandle operator*=(TemplatedIndexedValue e) {
|
||||
return this->operator*=(static_cast<ValueHandle>(e));
|
||||
return this->operator*=(static_cast<Value>(e));
|
||||
}
|
||||
OperationHandle operator/=(TemplatedIndexedValue e) {
|
||||
return this->operator/=(static_cast<ValueHandle>(e));
|
||||
return this->operator/=(static_cast<Value>(e));
|
||||
}
|
||||
OperationHandle operator%=(TemplatedIndexedValue e) {
|
||||
return this->operator%=(static_cast<ValueHandle>(e));
|
||||
return this->operator%=(static_cast<Value>(e));
|
||||
}
|
||||
OperationHandle operator^=(TemplatedIndexedValue e) {
|
||||
return this->operator^=(static_cast<ValueHandle>(e));
|
||||
return this->operator^=(static_cast<Value>(e));
|
||||
}
|
||||
|
||||
/// Logical operator overloadings.
|
||||
ValueHandle operator&&(ValueHandle e);
|
||||
ValueHandle operator||(ValueHandle e);
|
||||
ValueHandle operator&&(TemplatedIndexedValue e) {
|
||||
return *this && static_cast<ValueHandle>(e);
|
||||
Value operator&&(Value e);
|
||||
Value operator||(Value e);
|
||||
Value operator&&(TemplatedIndexedValue e) {
|
||||
return *this && static_cast<Value>(e);
|
||||
}
|
||||
ValueHandle operator||(TemplatedIndexedValue e) {
|
||||
return *this || static_cast<ValueHandle>(e);
|
||||
Value operator||(TemplatedIndexedValue e) {
|
||||
return *this || static_cast<Value>(e);
|
||||
}
|
||||
|
||||
/// Comparison operator overloadings.
|
||||
ValueHandle operator==(ValueHandle e);
|
||||
ValueHandle operator!=(ValueHandle e);
|
||||
ValueHandle operator<(ValueHandle e);
|
||||
ValueHandle operator<=(ValueHandle e);
|
||||
ValueHandle operator>(ValueHandle e);
|
||||
ValueHandle operator>=(ValueHandle e);
|
||||
ValueHandle operator==(TemplatedIndexedValue e) {
|
||||
return *this == static_cast<ValueHandle>(e);
|
||||
Value eq(Value e);
|
||||
Value ne(Value e);
|
||||
Value operator<(Value e);
|
||||
Value operator<=(Value e);
|
||||
Value operator>(Value e);
|
||||
Value operator>=(Value e);
|
||||
Value operator<(TemplatedIndexedValue e) {
|
||||
return *this < static_cast<Value>(e);
|
||||
}
|
||||
ValueHandle operator!=(TemplatedIndexedValue e) {
|
||||
return *this != static_cast<ValueHandle>(e);
|
||||
Value operator<=(TemplatedIndexedValue e) {
|
||||
return *this <= static_cast<Value>(e);
|
||||
}
|
||||
ValueHandle operator<(TemplatedIndexedValue e) {
|
||||
return *this < static_cast<ValueHandle>(e);
|
||||
Value operator>(TemplatedIndexedValue e) {
|
||||
return *this > static_cast<Value>(e);
|
||||
}
|
||||
ValueHandle operator<=(TemplatedIndexedValue e) {
|
||||
return *this <= static_cast<ValueHandle>(e);
|
||||
}
|
||||
ValueHandle operator>(TemplatedIndexedValue e) {
|
||||
return *this > static_cast<ValueHandle>(e);
|
||||
}
|
||||
ValueHandle operator>=(TemplatedIndexedValue e) {
|
||||
return *this >= static_cast<ValueHandle>(e);
|
||||
Value operator>=(TemplatedIndexedValue e) {
|
||||
return *this >= static_cast<Value>(e);
|
||||
}
|
||||
|
||||
private:
|
||||
TemplatedIndexedValue(ValueHandle base, ArrayRef<ValueHandle> indices)
|
||||
: base(base), indices(indices.begin(), indices.end()) {}
|
||||
TemplatedIndexedValue(Value value, ArrayRef<Value> indices)
|
||||
: value(value), indices(indices.begin(), indices.end()) {}
|
||||
|
||||
TemplatedIndexedValue &append() { return *this; }
|
||||
|
||||
template <typename T, typename... Args>
|
||||
TemplatedIndexedValue &append(T index, Args... indices) {
|
||||
this->indices.push_back(static_cast<ValueHandle>(index));
|
||||
this->indices.push_back(static_cast<Value>(index));
|
||||
append(indices...);
|
||||
return *this;
|
||||
}
|
||||
ValueHandle base;
|
||||
SmallVector<ValueHandle, 8> indices;
|
||||
Value value;
|
||||
SmallVector<Value, 8> indices;
|
||||
};
|
||||
|
||||
} // namespace edsc
|
||||
|
||||
@@ -25,99 +25,27 @@ class Type;
|
||||
|
||||
namespace edsc {
|
||||
|
||||
/// Entry point to build multiple ValueHandle* from a mutable list `ivs`.
|
||||
inline SmallVector<ValueHandle *, 8>
|
||||
makeHandlePointers(MutableArrayRef<ValueHandle> ivs) {
|
||||
SmallVector<ValueHandle *, 8> pivs;
|
||||
pivs.reserve(ivs.size());
|
||||
for (auto &iv : ivs)
|
||||
pivs.push_back(&iv);
|
||||
return pivs;
|
||||
}
|
||||
|
||||
/// Provides a set of first class intrinsics.
|
||||
/// In the future, most of the intrinsics related to Operation that don't contain
|
||||
/// other operations should be Tablegen'd.
|
||||
namespace intrinsics {
|
||||
namespace detail {
|
||||
/// Helper structure to be used with ValueBuilder / OperationBuilder.
|
||||
/// It serves the purpose of removing boilerplate specialization for the sole
|
||||
/// purpose of implicitly converting ArrayRef<ValueHandle> -> ArrayRef<Value>.
|
||||
class ValueHandleArray {
|
||||
public:
|
||||
ValueHandleArray(ArrayRef<ValueHandle> vals) {
|
||||
values.append(vals.begin(), vals.end());
|
||||
}
|
||||
operator ArrayRef<Value>() { return values; }
|
||||
|
||||
private:
|
||||
ValueHandleArray() = default;
|
||||
SmallVector<Value, 8> values;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
inline T unpack(T value) {
|
||||
return value;
|
||||
}
|
||||
|
||||
inline detail::ValueHandleArray unpack(ArrayRef<ValueHandle> values) {
|
||||
return detail::ValueHandleArray(values);
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
/// Helper variadic abstraction to allow extending to any MLIR op without
|
||||
/// boilerplate or Tablegen.
|
||||
/// Arguably a builder is not a ValueHandle but in practice it is only used as
|
||||
/// an alias to a notional ValueHandle<Op>.
|
||||
/// Implementing it as a subclass allows it to compose all the way to Value.
|
||||
/// Without subclassing, implicit conversion to Value would fail when composing
|
||||
/// in patterns such as: `select(a, b, select(c, d, e))`.
|
||||
template <typename Op>
|
||||
struct ValueBuilder : public ValueHandle {
|
||||
// Builder-based
|
||||
template <typename... Args>
|
||||
ValueBuilder(Args... args)
|
||||
: ValueHandle(ValueHandle::create<Op>(detail::unpack(args)...)) {}
|
||||
ValueBuilder(ArrayRef<ValueHandle> vs)
|
||||
: ValueBuilder(ValueBuilder::create<Op>(detail::unpack(vs))) {}
|
||||
template <typename... Args>
|
||||
ValueBuilder(ArrayRef<ValueHandle> vs, Args... args)
|
||||
: ValueHandle(ValueHandle::create<Op>(detail::unpack(vs),
|
||||
detail::unpack(args)...)) {}
|
||||
template <typename T, typename... Args>
|
||||
ValueBuilder(T t, ArrayRef<ValueHandle> vs, Args... args)
|
||||
: ValueHandle(ValueHandle::create<Op>(
|
||||
detail::unpack(t), detail::unpack(vs), detail::unpack(args)...)) {}
|
||||
template <typename T1, typename T2, typename... Args>
|
||||
ValueBuilder(T1 t1, T2 t2, ArrayRef<ValueHandle> vs, Args... args)
|
||||
: ValueHandle(ValueHandle::create<Op>(
|
||||
detail::unpack(t1), detail::unpack(t2), detail::unpack(vs),
|
||||
detail::unpack(args)...)) {}
|
||||
|
||||
ValueBuilder() : ValueHandle(ValueHandle::create<Op>()) {}
|
||||
};
|
||||
|
||||
template <typename Op>
|
||||
struct OperationBuilder : public OperationHandle {
|
||||
template <typename... Args>
|
||||
OperationBuilder(Args... args)
|
||||
: OperationHandle(OperationHandle::create<Op>(detail::unpack(args)...)) {}
|
||||
OperationBuilder(ArrayRef<ValueHandle> vs)
|
||||
: OperationHandle(OperationHandle::create<Op>(detail::unpack(vs))) {}
|
||||
: OperationHandle(OperationHandle::create<Op>(args...)) {}
|
||||
OperationBuilder(ArrayRef<Value> vs)
|
||||
: OperationHandle(OperationHandle::create<Op>(vs)) {}
|
||||
template <typename... Args>
|
||||
OperationBuilder(ArrayRef<ValueHandle> vs, Args... args)
|
||||
: OperationHandle(OperationHandle::create<Op>(detail::unpack(vs),
|
||||
detail::unpack(args)...)) {}
|
||||
OperationBuilder(ArrayRef<Value> vs, Args... args)
|
||||
: OperationHandle(OperationHandle::create<Op>(vs, args...)) {}
|
||||
template <typename T, typename... Args>
|
||||
OperationBuilder(T t, ArrayRef<ValueHandle> vs, Args... args)
|
||||
: OperationHandle(OperationHandle::create<Op>(
|
||||
detail::unpack(t), detail::unpack(vs), detail::unpack(args)...)) {}
|
||||
OperationBuilder(T t, ArrayRef<Value> vs, Args... args)
|
||||
: OperationHandle(OperationHandle::create<Op>(t, vs, args...)) {}
|
||||
template <typename T1, typename T2, typename... Args>
|
||||
OperationBuilder(T1 t1, T2 t2, ArrayRef<ValueHandle> vs, Args... args)
|
||||
: OperationHandle(OperationHandle::create<Op>(
|
||||
detail::unpack(t1), detail::unpack(t2), detail::unpack(vs),
|
||||
detail::unpack(args)...)) {}
|
||||
OperationBuilder(T1 t1, T2 t2, ArrayRef<Value> vs, Args... args)
|
||||
: OperationHandle(OperationHandle::create<Op>(t1, t2, vs, args...)) {}
|
||||
OperationBuilder() : OperationHandle(OperationHandle::create<Op>()) {}
|
||||
};
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
#include "mlir/Dialect/Affine/EDSC/Intrinsics.h"
|
||||
#include "mlir/Dialect/LoopOps/EDSC/Builders.h"
|
||||
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
|
||||
#include "mlir/Dialect/Vector/EDSC/Intrinsics.h"
|
||||
#include "mlir/Dialect/Vector/VectorOps.h"
|
||||
#include "mlir/IR/AffineExpr.h"
|
||||
#include "mlir/IR/AffineMap.h"
|
||||
@@ -38,9 +39,7 @@ using vector::TransferWriteOp;
|
||||
/// `pivs` and `vectorBoundsCapture` are swapped so that the invocation of
|
||||
/// LoopNestBuilder captures it in the innermost loop.
|
||||
template <typename TransferOpTy>
|
||||
static void coalesceCopy(TransferOpTy transfer,
|
||||
SmallVectorImpl<ValueHandle *> *pivs,
|
||||
VectorBoundsCapture *vectorBoundsCapture) {
|
||||
static int computeCoalescedIndex(TransferOpTy transfer) {
|
||||
// rank of the remote memory access, coalescing behavior occurs on the
|
||||
// innermost memory dimension.
|
||||
auto remoteRank = transfer.getMemRefType().getRank();
|
||||
@@ -62,24 +61,19 @@ static void coalesceCopy(TransferOpTy transfer,
|
||||
coalescedIdx = en.index();
|
||||
}
|
||||
}
|
||||
if (coalescedIdx >= 0) {
|
||||
std::swap(pivs->back(), (*pivs)[coalescedIdx]);
|
||||
vectorBoundsCapture->swapRanges(pivs->size() - 1, coalescedIdx);
|
||||
}
|
||||
return coalescedIdx;
|
||||
}
|
||||
|
||||
/// Emits remote memory accesses that are clipped to the boundaries of the
|
||||
/// MemRef.
|
||||
template <typename TransferOpTy>
|
||||
static SmallVector<ValueHandle, 8> clip(TransferOpTy transfer,
|
||||
MemRefBoundsCapture &bounds,
|
||||
ArrayRef<ValueHandle> ivs) {
|
||||
static SmallVector<Value, 8>
|
||||
clip(TransferOpTy transfer, MemRefBoundsCapture &bounds, ArrayRef<Value> ivs) {
|
||||
using namespace mlir::edsc;
|
||||
|
||||
ValueHandle zero(std_constant_index(0)), one(std_constant_index(1));
|
||||
SmallVector<ValueHandle, 8> memRefAccess(transfer.indices());
|
||||
auto clippedScalarAccessExprs =
|
||||
ValueHandle::makeIndexHandles(memRefAccess.size());
|
||||
Value zero(std_constant_index(0)), one(std_constant_index(1));
|
||||
SmallVector<Value, 8> memRefAccess(transfer.indices());
|
||||
SmallVector<Value, 8> clippedScalarAccessExprs(memRefAccess.size());
|
||||
// Indices accessing to remote memory are clipped and their expressions are
|
||||
// returned in clippedScalarAccessExprs.
|
||||
for (unsigned memRefDim = 0; memRefDim < clippedScalarAccessExprs.size();
|
||||
@@ -126,8 +120,6 @@ static SmallVector<ValueHandle, 8> clip(TransferOpTy transfer,
|
||||
|
||||
namespace {
|
||||
|
||||
using vector_type_cast = edsc::intrinsics::ValueBuilder<vector::TypeCastOp>;
|
||||
|
||||
/// Implements lowering of TransferReadOp and TransferWriteOp to a
|
||||
/// proper abstraction for the hardware.
|
||||
///
|
||||
@@ -257,31 +249,36 @@ LogicalResult VectorTransferRewriter<TransferReadOp>::matchAndRewrite(
|
||||
StdIndexedValue remote(transfer.memref());
|
||||
MemRefBoundsCapture memRefBoundsCapture(transfer.memref());
|
||||
VectorBoundsCapture vectorBoundsCapture(transfer.vector());
|
||||
auto ivs = ValueHandle::makeIndexHandles(vectorBoundsCapture.rank());
|
||||
SmallVector<ValueHandle *, 8> pivs =
|
||||
makeHandlePointers(MutableArrayRef<ValueHandle>(ivs));
|
||||
coalesceCopy(transfer, &pivs, &vectorBoundsCapture);
|
||||
int coalescedIdx = computeCoalescedIndex(transfer);
|
||||
// Swap the vectorBoundsCapture which will reorder loop bounds.
|
||||
if (coalescedIdx >= 0)
|
||||
vectorBoundsCapture.swapRanges(vectorBoundsCapture.rank() - 1,
|
||||
coalescedIdx);
|
||||
|
||||
auto lbs = vectorBoundsCapture.getLbs();
|
||||
auto ubs = vectorBoundsCapture.getUbs();
|
||||
SmallVector<ValueHandle, 8> steps;
|
||||
SmallVector<Value, 8> steps;
|
||||
steps.reserve(vectorBoundsCapture.getSteps().size());
|
||||
for (auto step : vectorBoundsCapture.getSteps())
|
||||
steps.push_back(std_constant_index(step));
|
||||
|
||||
// 2. Emit alloc-copy-load-dealloc.
|
||||
ValueHandle tmp = std_alloc(tmpMemRefType(transfer));
|
||||
Value tmp = std_alloc(tmpMemRefType(transfer));
|
||||
StdIndexedValue local(tmp);
|
||||
ValueHandle vec = vector_type_cast(tmp);
|
||||
LoopNestBuilder(pivs, lbs, ubs, steps)([&] {
|
||||
Value vec = vector_type_cast(tmp);
|
||||
SmallVector<Value, 8> ivs(lbs.size());
|
||||
LoopNestBuilder(ivs, lbs, ubs, steps)([&] {
|
||||
// Swap the ivs which will reorder memory accesses.
|
||||
if (coalescedIdx >= 0)
|
||||
std::swap(ivs.back(), ivs[coalescedIdx]);
|
||||
// Computes clippedScalarAccessExprs in the loop nest scope (ivs exist).
|
||||
local(ivs) = remote(clip(transfer, memRefBoundsCapture, ivs));
|
||||
});
|
||||
ValueHandle vectorValue = std_load(vec);
|
||||
Value vectorValue = std_load(vec);
|
||||
(std_dealloc(tmp)); // vexing parse
|
||||
|
||||
// 3. Propagate.
|
||||
rewriter.replaceOp(op, vectorValue.getValue());
|
||||
rewriter.replaceOp(op, vectorValue);
|
||||
return success();
|
||||
}
|
||||
|
||||
@@ -314,26 +311,31 @@ LogicalResult VectorTransferRewriter<TransferWriteOp>::matchAndRewrite(
|
||||
ScopedContext scope(rewriter, transfer.getLoc());
|
||||
StdIndexedValue remote(transfer.memref());
|
||||
MemRefBoundsCapture memRefBoundsCapture(transfer.memref());
|
||||
ValueHandle vectorValue(transfer.vector());
|
||||
Value vectorValue(transfer.vector());
|
||||
VectorBoundsCapture vectorBoundsCapture(transfer.vector());
|
||||
auto ivs = ValueHandle::makeIndexHandles(vectorBoundsCapture.rank());
|
||||
SmallVector<ValueHandle *, 8> pivs =
|
||||
makeHandlePointers(MutableArrayRef<ValueHandle>(ivs));
|
||||
coalesceCopy(transfer, &pivs, &vectorBoundsCapture);
|
||||
int coalescedIdx = computeCoalescedIndex(transfer);
|
||||
// Swap the vectorBoundsCapture which will reorder loop bounds.
|
||||
if (coalescedIdx >= 0)
|
||||
vectorBoundsCapture.swapRanges(vectorBoundsCapture.rank() - 1,
|
||||
coalescedIdx);
|
||||
|
||||
auto lbs = vectorBoundsCapture.getLbs();
|
||||
auto ubs = vectorBoundsCapture.getUbs();
|
||||
SmallVector<ValueHandle, 8> steps;
|
||||
SmallVector<Value, 8> steps;
|
||||
steps.reserve(vectorBoundsCapture.getSteps().size());
|
||||
for (auto step : vectorBoundsCapture.getSteps())
|
||||
steps.push_back(std_constant_index(step));
|
||||
|
||||
// 2. Emit alloc-store-copy-dealloc.
|
||||
ValueHandle tmp = std_alloc(tmpMemRefType(transfer));
|
||||
Value tmp = std_alloc(tmpMemRefType(transfer));
|
||||
StdIndexedValue local(tmp);
|
||||
ValueHandle vec = vector_type_cast(tmp);
|
||||
Value vec = vector_type_cast(tmp);
|
||||
std_store(vectorValue, vec);
|
||||
LoopNestBuilder(pivs, lbs, ubs, steps)([&] {
|
||||
SmallVector<Value, 8> ivs(lbs.size());
|
||||
LoopNestBuilder(ivs, lbs, ubs, steps)([&] {
|
||||
// Swap the ivs which will reorder memory accesses.
|
||||
if (coalescedIdx >= 0)
|
||||
std::swap(ivs.back(), ivs[coalescedIdx]);
|
||||
// Computes clippedScalarAccessExprs in the loop nest scope (ivs exist).
|
||||
remote(clip(transfer, memRefBoundsCapture, ivs)) = local(ivs);
|
||||
});
|
||||
|
||||
@@ -14,65 +14,61 @@
|
||||
using namespace mlir;
|
||||
using namespace mlir::edsc;
|
||||
|
||||
static Optional<ValueHandle> emitStaticFor(ArrayRef<ValueHandle> lbs,
|
||||
ArrayRef<ValueHandle> ubs,
|
||||
int64_t step) {
|
||||
static Optional<Value> emitStaticFor(ArrayRef<Value> lbs, ArrayRef<Value> ubs,
|
||||
int64_t step) {
|
||||
if (lbs.size() != 1 || ubs.size() != 1)
|
||||
return Optional<ValueHandle>();
|
||||
return Optional<Value>();
|
||||
|
||||
auto *lbDef = lbs.front().getValue().getDefiningOp();
|
||||
auto *ubDef = ubs.front().getValue().getDefiningOp();
|
||||
auto *lbDef = lbs.front().getDefiningOp();
|
||||
auto *ubDef = ubs.front().getDefiningOp();
|
||||
if (!lbDef || !ubDef)
|
||||
return Optional<ValueHandle>();
|
||||
return Optional<Value>();
|
||||
|
||||
auto lbConst = dyn_cast<ConstantIndexOp>(lbDef);
|
||||
auto ubConst = dyn_cast<ConstantIndexOp>(ubDef);
|
||||
if (!lbConst || !ubConst)
|
||||
return Optional<ValueHandle>();
|
||||
|
||||
return ValueHandle(ScopedContext::getBuilder()
|
||||
.create<AffineForOp>(ScopedContext::getLocation(),
|
||||
lbConst.getValue(),
|
||||
ubConst.getValue(), step)
|
||||
.getInductionVar());
|
||||
return Optional<Value>();
|
||||
return ScopedContext::getBuilder()
|
||||
.create<AffineForOp>(ScopedContext::getLocation(), lbConst.getValue(),
|
||||
ubConst.getValue(), step)
|
||||
.getInductionVar();
|
||||
}
|
||||
|
||||
LoopBuilder mlir::edsc::makeAffineLoopBuilder(ValueHandle *iv,
|
||||
ArrayRef<ValueHandle> lbHandles,
|
||||
ArrayRef<ValueHandle> ubHandles,
|
||||
LoopBuilder mlir::edsc::makeAffineLoopBuilder(Value *iv, ArrayRef<Value> lbs,
|
||||
ArrayRef<Value> ubs,
|
||||
int64_t step) {
|
||||
mlir::edsc::LoopBuilder result;
|
||||
if (auto staticFor = emitStaticFor(lbHandles, ubHandles, step)) {
|
||||
*iv = staticFor.getValue();
|
||||
if (auto staticForIv = emitStaticFor(lbs, ubs, step)) {
|
||||
*iv = staticForIv.getValue();
|
||||
} else {
|
||||
SmallVector<Value, 4> lbs(lbHandles.begin(), lbHandles.end());
|
||||
SmallVector<Value, 4> ubs(ubHandles.begin(), ubHandles.end());
|
||||
auto b = ScopedContext::getBuilder();
|
||||
*iv = ValueHandle(
|
||||
b.create<AffineForOp>(ScopedContext::getLocation(), lbs,
|
||||
b.getMultiDimIdentityMap(lbs.size()), ubs,
|
||||
b.getMultiDimIdentityMap(ubs.size()), step)
|
||||
.getInductionVar());
|
||||
*iv =
|
||||
Value(b.create<AffineForOp>(ScopedContext::getLocation(), lbs,
|
||||
b.getMultiDimIdentityMap(lbs.size()), ubs,
|
||||
b.getMultiDimIdentityMap(ubs.size()), step)
|
||||
.getInductionVar());
|
||||
}
|
||||
auto *body = getForInductionVarOwner(iv->getValue()).getBody();
|
||||
|
||||
auto *body = getForInductionVarOwner(*iv).getBody();
|
||||
result.enter(body, /*prev=*/1);
|
||||
return result;
|
||||
}
|
||||
|
||||
mlir::edsc::AffineLoopNestBuilder::AffineLoopNestBuilder(
|
||||
ValueHandle *iv, ArrayRef<ValueHandle> lbs, ArrayRef<ValueHandle> ubs,
|
||||
int64_t step) {
|
||||
mlir::edsc::AffineLoopNestBuilder::AffineLoopNestBuilder(Value *iv,
|
||||
ArrayRef<Value> lbs,
|
||||
ArrayRef<Value> ubs,
|
||||
int64_t step) {
|
||||
loops.emplace_back(makeAffineLoopBuilder(iv, lbs, ubs, step));
|
||||
}
|
||||
|
||||
mlir::edsc::AffineLoopNestBuilder::AffineLoopNestBuilder(
|
||||
ArrayRef<ValueHandle *> ivs, ArrayRef<ValueHandle> lbs,
|
||||
ArrayRef<ValueHandle> ubs, ArrayRef<int64_t> steps) {
|
||||
MutableArrayRef<Value> ivs, ArrayRef<Value> lbs, ArrayRef<Value> ubs,
|
||||
ArrayRef<int64_t> steps) {
|
||||
assert(ivs.size() == lbs.size() && "Mismatch in number of arguments");
|
||||
assert(ivs.size() == ubs.size() && "Mismatch in number of arguments");
|
||||
assert(ivs.size() == steps.size() && "Mismatch in number of arguments");
|
||||
for (auto it : llvm::zip(ivs, lbs, ubs, steps))
|
||||
loops.emplace_back(makeAffineLoopBuilder(std::get<0>(it), std::get<1>(it),
|
||||
loops.emplace_back(makeAffineLoopBuilder(&std::get<0>(it), std::get<1>(it),
|
||||
std::get<2>(it), std::get<3>(it)));
|
||||
}
|
||||
|
||||
@@ -89,11 +85,6 @@ void mlir::edsc::AffineLoopNestBuilder::operator()(
|
||||
(*lit)();
|
||||
}
|
||||
|
||||
template <typename Op>
|
||||
static ValueHandle createBinaryHandle(ValueHandle lhs, ValueHandle rhs) {
|
||||
return ValueHandle::create<Op>(lhs.getValue(), rhs.getValue());
|
||||
}
|
||||
|
||||
static std::pair<AffineExpr, Value>
|
||||
categorizeValueByAffineType(MLIRContext *context, Value val, unsigned &numDims,
|
||||
unsigned &numSymbols) {
|
||||
@@ -111,115 +102,109 @@ categorizeValueByAffineType(MLIRContext *context, Value val, unsigned &numDims,
|
||||
return std::make_pair(d, resultVal);
|
||||
}
|
||||
|
||||
static ValueHandle createBinaryIndexHandle(
|
||||
ValueHandle lhs, ValueHandle rhs,
|
||||
static Value createBinaryIndexHandle(
|
||||
Value lhs, Value rhs,
|
||||
function_ref<AffineExpr(AffineExpr, AffineExpr)> affCombiner) {
|
||||
MLIRContext *context = ScopedContext::getContext();
|
||||
unsigned numDims = 0, numSymbols = 0;
|
||||
AffineExpr d0, d1;
|
||||
Value v0, v1;
|
||||
std::tie(d0, v0) =
|
||||
categorizeValueByAffineType(context, lhs.getValue(), numDims, numSymbols);
|
||||
categorizeValueByAffineType(context, lhs, numDims, numSymbols);
|
||||
std::tie(d1, v1) =
|
||||
categorizeValueByAffineType(context, rhs.getValue(), numDims, numSymbols);
|
||||
categorizeValueByAffineType(context, rhs, numDims, numSymbols);
|
||||
SmallVector<Value, 2> operands;
|
||||
if (v0) {
|
||||
if (v0)
|
||||
operands.push_back(v0);
|
||||
}
|
||||
if (v1) {
|
||||
if (v1)
|
||||
operands.push_back(v1);
|
||||
}
|
||||
auto map = AffineMap::get(numDims, numSymbols, affCombiner(d0, d1));
|
||||
|
||||
// TODO: createOrFold when available.
|
||||
Operation *op =
|
||||
makeComposedAffineApply(ScopedContext::getBuilder(),
|
||||
ScopedContext::getLocation(), map, operands)
|
||||
.getOperation();
|
||||
assert(op->getNumResults() == 1 && "Expected single result AffineApply");
|
||||
return ValueHandle(op->getResult(0));
|
||||
return op->getResult(0);
|
||||
}
|
||||
|
||||
template <typename IOp, typename FOp>
|
||||
static ValueHandle createBinaryHandle(
|
||||
ValueHandle lhs, ValueHandle rhs,
|
||||
static Value createBinaryHandle(
|
||||
Value lhs, Value rhs,
|
||||
function_ref<AffineExpr(AffineExpr, AffineExpr)> affCombiner) {
|
||||
auto thisType = lhs.getValue().getType();
|
||||
auto thatType = rhs.getValue().getType();
|
||||
auto thisType = lhs.getType();
|
||||
auto thatType = rhs.getType();
|
||||
assert(thisType == thatType && "cannot mix types in operators");
|
||||
(void)thisType;
|
||||
(void)thatType;
|
||||
if (thisType.isIndex()) {
|
||||
return createBinaryIndexHandle(lhs, rhs, affCombiner);
|
||||
} else if (thisType.isSignlessInteger()) {
|
||||
return createBinaryHandle<IOp>(lhs, rhs);
|
||||
return ValueBuilder<IOp>(lhs, rhs);
|
||||
} else if (thisType.isa<FloatType>()) {
|
||||
return createBinaryHandle<FOp>(lhs, rhs);
|
||||
return ValueBuilder<FOp>(lhs, rhs);
|
||||
} else if (thisType.isa<VectorType>() || thisType.isa<TensorType>()) {
|
||||
auto aggregateType = thisType.cast<ShapedType>();
|
||||
if (aggregateType.getElementType().isSignlessInteger())
|
||||
return createBinaryHandle<IOp>(lhs, rhs);
|
||||
return ValueBuilder<IOp>(lhs, rhs);
|
||||
else if (aggregateType.getElementType().isa<FloatType>())
|
||||
return createBinaryHandle<FOp>(lhs, rhs);
|
||||
return ValueBuilder<FOp>(lhs, rhs);
|
||||
}
|
||||
llvm_unreachable("failed to create a ValueHandle");
|
||||
llvm_unreachable("failed to create a Value");
|
||||
}
|
||||
|
||||
ValueHandle mlir::edsc::op::operator+(ValueHandle lhs, ValueHandle rhs) {
|
||||
Value mlir::edsc::op::operator+(Value lhs, Value rhs) {
|
||||
return createBinaryHandle<AddIOp, AddFOp>(
|
||||
lhs, rhs, [](AffineExpr d0, AffineExpr d1) { return d0 + d1; });
|
||||
}
|
||||
|
||||
ValueHandle mlir::edsc::op::operator-(ValueHandle lhs, ValueHandle rhs) {
|
||||
Value mlir::edsc::op::operator-(Value lhs, Value rhs) {
|
||||
return createBinaryHandle<SubIOp, SubFOp>(
|
||||
lhs, rhs, [](AffineExpr d0, AffineExpr d1) { return d0 - d1; });
|
||||
}
|
||||
|
||||
ValueHandle mlir::edsc::op::operator*(ValueHandle lhs, ValueHandle rhs) {
|
||||
Value mlir::edsc::op::operator*(Value lhs, Value rhs) {
|
||||
return createBinaryHandle<MulIOp, MulFOp>(
|
||||
lhs, rhs, [](AffineExpr d0, AffineExpr d1) { return d0 * d1; });
|
||||
}
|
||||
|
||||
ValueHandle mlir::edsc::op::operator/(ValueHandle lhs, ValueHandle rhs) {
|
||||
Value mlir::edsc::op::operator/(Value lhs, Value rhs) {
|
||||
return createBinaryHandle<SignedDivIOp, DivFOp>(
|
||||
lhs, rhs, [](AffineExpr d0, AffineExpr d1) -> AffineExpr {
|
||||
llvm_unreachable("only exprs of non-index type support operator/");
|
||||
});
|
||||
}
|
||||
|
||||
ValueHandle mlir::edsc::op::operator%(ValueHandle lhs, ValueHandle rhs) {
|
||||
Value mlir::edsc::op::operator%(Value lhs, Value rhs) {
|
||||
return createBinaryHandle<SignedRemIOp, RemFOp>(
|
||||
lhs, rhs, [](AffineExpr d0, AffineExpr d1) { return d0 % d1; });
|
||||
}
|
||||
|
||||
ValueHandle mlir::edsc::op::floorDiv(ValueHandle lhs, ValueHandle rhs) {
|
||||
Value mlir::edsc::op::floorDiv(Value lhs, Value rhs) {
|
||||
return createBinaryIndexHandle(
|
||||
lhs, rhs, [](AffineExpr d0, AffineExpr d1) { return d0.floorDiv(d1); });
|
||||
}
|
||||
|
||||
ValueHandle mlir::edsc::op::ceilDiv(ValueHandle lhs, ValueHandle rhs) {
|
||||
Value mlir::edsc::op::ceilDiv(Value lhs, Value rhs) {
|
||||
return createBinaryIndexHandle(
|
||||
lhs, rhs, [](AffineExpr d0, AffineExpr d1) { return d0.ceilDiv(d1); });
|
||||
}
|
||||
|
||||
ValueHandle mlir::edsc::op::operator!(ValueHandle value) {
|
||||
assert(value.getType().isInteger(1) && "expected boolean expression");
|
||||
return ValueHandle::create<ConstantIntOp>(1, 1) - value;
|
||||
}
|
||||
|
||||
ValueHandle mlir::edsc::op::operator&&(ValueHandle lhs, ValueHandle rhs) {
|
||||
Value mlir::edsc::op::operator&&(Value lhs, Value rhs) {
|
||||
assert(lhs.getType().isInteger(1) && "expected boolean expression on LHS");
|
||||
assert(rhs.getType().isInteger(1) && "expected boolean expression on RHS");
|
||||
return ValueHandle::create<AndOp>(lhs, rhs);
|
||||
return ValueBuilder<AndOp>(lhs, rhs);
|
||||
}
|
||||
|
||||
ValueHandle mlir::edsc::op::operator||(ValueHandle lhs, ValueHandle rhs) {
|
||||
Value mlir::edsc::op::operator||(Value lhs, Value rhs) {
|
||||
assert(lhs.getType().isInteger(1) && "expected boolean expression on LHS");
|
||||
assert(rhs.getType().isInteger(1) && "expected boolean expression on RHS");
|
||||
return ValueHandle::create<OrOp>(lhs, rhs);
|
||||
return ValueBuilder<OrOp>(lhs, rhs);
|
||||
}
|
||||
|
||||
static ValueHandle createIComparisonExpr(CmpIPredicate predicate,
|
||||
ValueHandle lhs, ValueHandle rhs) {
|
||||
static Value createIComparisonExpr(CmpIPredicate predicate, Value lhs,
|
||||
Value rhs) {
|
||||
auto lhsType = lhs.getType();
|
||||
auto rhsType = rhs.getType();
|
||||
(void)lhsType;
|
||||
@@ -228,13 +213,12 @@ static ValueHandle createIComparisonExpr(CmpIPredicate predicate,
|
||||
assert((lhsType.isa<IndexType>() || lhsType.isSignlessInteger()) &&
|
||||
"only integer comparisons are supported");
|
||||
|
||||
auto op = ScopedContext::getBuilder().create<CmpIOp>(
|
||||
ScopedContext::getLocation(), predicate, lhs.getValue(), rhs.getValue());
|
||||
return ValueHandle(op.getResult());
|
||||
return ScopedContext::getBuilder().create<CmpIOp>(
|
||||
ScopedContext::getLocation(), predicate, lhs, rhs);
|
||||
}
|
||||
|
||||
static ValueHandle createFComparisonExpr(CmpFPredicate predicate,
|
||||
ValueHandle lhs, ValueHandle rhs) {
|
||||
static Value createFComparisonExpr(CmpFPredicate predicate, Value lhs,
|
||||
Value rhs) {
|
||||
auto lhsType = lhs.getType();
|
||||
auto rhsType = rhs.getType();
|
||||
(void)lhsType;
|
||||
@@ -242,25 +226,24 @@ static ValueHandle createFComparisonExpr(CmpFPredicate predicate,
|
||||
assert(lhsType == rhsType && "cannot mix types in operators");
|
||||
assert(lhsType.isa<FloatType>() && "only float comparisons are supported");
|
||||
|
||||
auto op = ScopedContext::getBuilder().create<CmpFOp>(
|
||||
ScopedContext::getLocation(), predicate, lhs.getValue(), rhs.getValue());
|
||||
return ValueHandle(op.getResult());
|
||||
return ScopedContext::getBuilder().create<CmpFOp>(
|
||||
ScopedContext::getLocation(), predicate, lhs, rhs);
|
||||
}
|
||||
|
||||
// All floating point comparison are ordered through EDSL
|
||||
ValueHandle mlir::edsc::op::operator==(ValueHandle lhs, ValueHandle rhs) {
|
||||
Value mlir::edsc::op::eq(Value lhs, Value rhs) {
|
||||
auto type = lhs.getType();
|
||||
return type.isa<FloatType>()
|
||||
? createFComparisonExpr(CmpFPredicate::OEQ, lhs, rhs)
|
||||
: createIComparisonExpr(CmpIPredicate::eq, lhs, rhs);
|
||||
}
|
||||
ValueHandle mlir::edsc::op::operator!=(ValueHandle lhs, ValueHandle rhs) {
|
||||
Value mlir::edsc::op::ne(Value lhs, Value rhs) {
|
||||
auto type = lhs.getType();
|
||||
return type.isa<FloatType>()
|
||||
? createFComparisonExpr(CmpFPredicate::ONE, lhs, rhs)
|
||||
: createIComparisonExpr(CmpIPredicate::ne, lhs, rhs);
|
||||
}
|
||||
ValueHandle mlir::edsc::op::operator<(ValueHandle lhs, ValueHandle rhs) {
|
||||
Value mlir::edsc::op::operator<(Value lhs, Value rhs) {
|
||||
auto type = lhs.getType();
|
||||
return type.isa<FloatType>()
|
||||
? createFComparisonExpr(CmpFPredicate::OLT, lhs, rhs)
|
||||
@@ -268,19 +251,19 @@ ValueHandle mlir::edsc::op::operator<(ValueHandle lhs, ValueHandle rhs) {
|
||||
// TODO(ntv,zinenko): signed by default, how about unsigned?
|
||||
createIComparisonExpr(CmpIPredicate::slt, lhs, rhs);
|
||||
}
|
||||
ValueHandle mlir::edsc::op::operator<=(ValueHandle lhs, ValueHandle rhs) {
|
||||
Value mlir::edsc::op::operator<=(Value lhs, Value rhs) {
|
||||
auto type = lhs.getType();
|
||||
return type.isa<FloatType>()
|
||||
? createFComparisonExpr(CmpFPredicate::OLE, lhs, rhs)
|
||||
: createIComparisonExpr(CmpIPredicate::sle, lhs, rhs);
|
||||
}
|
||||
ValueHandle mlir::edsc::op::operator>(ValueHandle lhs, ValueHandle rhs) {
|
||||
Value mlir::edsc::op::operator>(Value lhs, Value rhs) {
|
||||
auto type = lhs.getType();
|
||||
return type.isa<FloatType>()
|
||||
? createFComparisonExpr(CmpFPredicate::OGT, lhs, rhs)
|
||||
: createIComparisonExpr(CmpIPredicate::sgt, lhs, rhs);
|
||||
}
|
||||
ValueHandle mlir::edsc::op::operator>=(ValueHandle lhs, ValueHandle rhs) {
|
||||
Value mlir::edsc::op::operator>=(Value lhs, Value rhs) {
|
||||
auto type = lhs.getType();
|
||||
return type.isa<FloatType>()
|
||||
? createFComparisonExpr(CmpFPredicate::OGE, lhs, rhs)
|
||||
|
||||
@@ -44,14 +44,14 @@ static void insertCopyLoops(OpBuilder &builder, Location loc,
|
||||
MemRefBoundsCapture &bounds, Value from, Value to) {
|
||||
// Create EDSC handles for bounds.
|
||||
unsigned rank = bounds.rank();
|
||||
SmallVector<ValueHandle, 4> lbs, ubs, steps;
|
||||
SmallVector<Value, 4> lbs, ubs, steps;
|
||||
|
||||
// Make sure we have enough loops to use all thread dimensions, these trivial
|
||||
// loops should be outermost and therefore inserted first.
|
||||
if (rank < GPUDialect::getNumWorkgroupDimensions()) {
|
||||
unsigned extraLoops = GPUDialect::getNumWorkgroupDimensions() - rank;
|
||||
ValueHandle zero = std_constant_index(0);
|
||||
ValueHandle one = std_constant_index(1);
|
||||
Value zero = std_constant_index(0);
|
||||
Value one = std_constant_index(1);
|
||||
lbs.resize(extraLoops, zero);
|
||||
ubs.resize(extraLoops, one);
|
||||
steps.resize(extraLoops, one);
|
||||
@@ -78,9 +78,8 @@ static void insertCopyLoops(OpBuilder &builder, Location loc,
|
||||
}
|
||||
|
||||
// Produce the loop nest with copies.
|
||||
SmallVector<ValueHandle, 8> ivs(lbs.size(), ValueHandle(indexType));
|
||||
auto ivPtrs = makeHandlePointers(MutableArrayRef<ValueHandle>(ivs));
|
||||
LoopNestBuilder(ivPtrs, lbs, ubs, steps)([&]() {
|
||||
SmallVector<Value, 8> ivs(lbs.size());
|
||||
LoopNestBuilder(ivs, lbs, ubs, steps)([&]() {
|
||||
auto activeIvs = llvm::makeArrayRef(ivs).take_back(rank);
|
||||
StdIndexedValue fromHandle(from), toHandle(to);
|
||||
toHandle(activeIvs) = fromHandle(activeIvs);
|
||||
@@ -90,8 +89,8 @@ static void insertCopyLoops(OpBuilder &builder, Location loc,
|
||||
for (auto en :
|
||||
llvm::enumerate(llvm::reverse(llvm::makeArrayRef(ivs).take_back(
|
||||
GPUDialect::getNumWorkgroupDimensions())))) {
|
||||
auto loop = cast<loop::ForOp>(
|
||||
en.value().getValue().getParentRegion()->getParentOp());
|
||||
Value v = en.value();
|
||||
auto loop = cast<loop::ForOp>(v.getParentRegion()->getParentOp());
|
||||
mapLoopToProcessorIds(loop, {threadIds[en.index()]},
|
||||
{blockDims[en.index()]});
|
||||
}
|
||||
|
||||
@@ -21,69 +21,61 @@ using namespace mlir::edsc::intrinsics;
|
||||
using namespace mlir::linalg;
|
||||
using namespace mlir::loop;
|
||||
|
||||
mlir::edsc::LoopRangeBuilder::LoopRangeBuilder(ValueHandle *iv,
|
||||
ValueHandle range) {
|
||||
mlir::edsc::LoopRangeBuilder::LoopRangeBuilder(Value *iv, Value range) {
|
||||
assert(range.getType() && "expected !linalg.range type");
|
||||
assert(range.getValue().getDefiningOp() &&
|
||||
"need operations to extract range parts");
|
||||
auto rangeOp = cast<RangeOp>(range.getValue().getDefiningOp());
|
||||
assert(range.getDefiningOp() && "need operations to extract range parts");
|
||||
auto rangeOp = cast<RangeOp>(range.getDefiningOp());
|
||||
auto lb = rangeOp.min();
|
||||
auto ub = rangeOp.max();
|
||||
auto step = rangeOp.step();
|
||||
auto forOp = OperationHandle::createOp<ForOp>(lb, ub, step);
|
||||
*iv = ValueHandle(forOp.getInductionVar());
|
||||
*iv = forOp.getInductionVar();
|
||||
auto *body = forOp.getBody();
|
||||
enter(body, /*prev=*/1);
|
||||
}
|
||||
|
||||
mlir::edsc::LoopRangeBuilder::LoopRangeBuilder(ValueHandle *iv,
|
||||
mlir::edsc::LoopRangeBuilder::LoopRangeBuilder(Value *iv,
|
||||
SubViewOp::Range range) {
|
||||
auto forOp =
|
||||
OperationHandle::createOp<ForOp>(range.offset, range.size, range.stride);
|
||||
*iv = ValueHandle(forOp.getInductionVar());
|
||||
*iv = forOp.getInductionVar();
|
||||
auto *body = forOp.getBody();
|
||||
enter(body, /*prev=*/1);
|
||||
}
|
||||
|
||||
ValueHandle
|
||||
mlir::edsc::LoopRangeBuilder::operator()(std::function<void(void)> fun) {
|
||||
Value mlir::edsc::LoopRangeBuilder::operator()(std::function<void(void)> fun) {
|
||||
if (fun)
|
||||
fun();
|
||||
exit();
|
||||
return ValueHandle::null();
|
||||
return Value();
|
||||
}
|
||||
|
||||
mlir::edsc::LoopNestRangeBuilder::LoopNestRangeBuilder(
|
||||
ArrayRef<ValueHandle *> ivs, ArrayRef<SubViewOp::Range> ranges) {
|
||||
MutableArrayRef<Value> ivs, ArrayRef<SubViewOp::Range> ranges) {
|
||||
loops.reserve(ranges.size());
|
||||
for (unsigned i = 0, e = ranges.size(); i < e; ++i) {
|
||||
loops.emplace_back(ivs[i], ranges[i]);
|
||||
loops.emplace_back(&ivs[i], ranges[i]);
|
||||
}
|
||||
assert(loops.size() == ivs.size() && "Mismatch loops vs ivs size");
|
||||
}
|
||||
|
||||
mlir::edsc::LoopNestRangeBuilder::LoopNestRangeBuilder(
|
||||
ArrayRef<ValueHandle *> ivs, ArrayRef<ValueHandle> ranges) {
|
||||
MutableArrayRef<Value> ivs, ArrayRef<Value> ranges) {
|
||||
loops.reserve(ranges.size());
|
||||
for (unsigned i = 0, e = ranges.size(); i < e; ++i) {
|
||||
loops.emplace_back(ivs[i], ranges[i]);
|
||||
loops.emplace_back(&ivs[i], ranges[i]);
|
||||
}
|
||||
assert(loops.size() == ivs.size() && "Mismatch loops vs ivs size");
|
||||
}
|
||||
|
||||
mlir::edsc::LoopNestRangeBuilder::LoopNestRangeBuilder(
|
||||
ArrayRef<ValueHandle *> ivs, ArrayRef<Value> ranges)
|
||||
: LoopNestRangeBuilder(
|
||||
ivs, SmallVector<ValueHandle, 4>(ranges.begin(), ranges.end())) {}
|
||||
|
||||
ValueHandle LoopNestRangeBuilder::LoopNestRangeBuilder::operator()(
|
||||
Value LoopNestRangeBuilder::LoopNestRangeBuilder::operator()(
|
||||
std::function<void(void)> fun) {
|
||||
if (fun)
|
||||
fun();
|
||||
for (auto &lit : reverse(loops)) {
|
||||
lit({});
|
||||
}
|
||||
return ValueHandle::null();
|
||||
return Value();
|
||||
}
|
||||
|
||||
namespace mlir {
|
||||
@@ -91,15 +83,15 @@ namespace edsc {
|
||||
|
||||
template <>
|
||||
GenericLoopNestRangeBuilder<loop::ForOp>::GenericLoopNestRangeBuilder(
|
||||
ArrayRef<edsc::ValueHandle *> ivs, ArrayRef<Value> ranges) {
|
||||
MutableArrayRef<Value> ivs, ArrayRef<Value> ranges) {
|
||||
builder = std::make_unique<LoopNestRangeBuilder>(ivs, ranges);
|
||||
}
|
||||
|
||||
template <>
|
||||
GenericLoopNestRangeBuilder<AffineForOp>::GenericLoopNestRangeBuilder(
|
||||
ArrayRef<ValueHandle *> ivs, ArrayRef<Value> ranges) {
|
||||
SmallVector<ValueHandle, 4> lbs;
|
||||
SmallVector<ValueHandle, 4> ubs;
|
||||
MutableArrayRef<Value> ivs, ArrayRef<Value> ranges) {
|
||||
SmallVector<Value, 4> lbs;
|
||||
SmallVector<Value, 4> ubs;
|
||||
SmallVector<int64_t, 4> steps;
|
||||
for (Value range : ranges) {
|
||||
assert(range.getType() && "expected linalg.range type");
|
||||
@@ -114,8 +106,8 @@ GenericLoopNestRangeBuilder<AffineForOp>::GenericLoopNestRangeBuilder(
|
||||
|
||||
template <>
|
||||
GenericLoopNestRangeBuilder<loop::ParallelOp>::GenericLoopNestRangeBuilder(
|
||||
ArrayRef<ValueHandle *> ivs, ArrayRef<Value> ranges) {
|
||||
SmallVector<ValueHandle, 4> lbs, ubs, steps;
|
||||
MutableArrayRef<Value> ivs, ArrayRef<Value> ranges) {
|
||||
SmallVector<Value, 4> lbs, ubs, steps;
|
||||
for (Value range : ranges) {
|
||||
assert(range.getType() && "expected linalg.range type");
|
||||
assert(range.getDefiningOp() && "need operations to extract range parts");
|
||||
@@ -197,10 +189,9 @@ Operation *mlir::edsc::makeGenericLinalgOp(
|
||||
OpBuilder opBuilder(op);
|
||||
ScopedContext scope(opBuilder, op->getLoc());
|
||||
BlockHandle b;
|
||||
auto handles = makeValueHandles(blockTypes);
|
||||
BlockBuilder(&b, op->getRegion(0),
|
||||
makeHandlePointers(MutableArrayRef<ValueHandle>(handles)))(
|
||||
[&] { regionBuilder(b.getBlock()->getArguments()); });
|
||||
SmallVector<Value, 8> handles(blockTypes.size());
|
||||
BlockBuilder(&b, op->getRegion(0), blockTypes,
|
||||
handles)([&] { regionBuilder(b.getBlock()->getArguments()); });
|
||||
assert(op->getRegion(0).getBlocks().size() == 1);
|
||||
return op;
|
||||
}
|
||||
@@ -209,16 +200,16 @@ void mlir::edsc::ops::mulRegionBuilder(ArrayRef<BlockArgument> args) {
|
||||
using edsc::op::operator+;
|
||||
using edsc::op::operator*;
|
||||
assert(args.size() == 2 && "expected 2 block arguments");
|
||||
ValueHandle a(args[0]), b(args[1]);
|
||||
linalg_yield((a * b).getValue());
|
||||
Value a(args[0]), b(args[1]);
|
||||
linalg_yield(a * b);
|
||||
}
|
||||
|
||||
void mlir::edsc::ops::macRegionBuilder(ArrayRef<BlockArgument> args) {
|
||||
using edsc::op::operator+;
|
||||
using edsc::op::operator*;
|
||||
assert(args.size() == 3 && "expected 3 block arguments");
|
||||
ValueHandle a(args[0]), b(args[1]), c(args[2]);
|
||||
linalg_yield((c + a * b).getValue());
|
||||
Value a(args[0]), b(args[1]), c(args[2]);
|
||||
linalg_yield(c + a * b);
|
||||
}
|
||||
|
||||
Operation *mlir::edsc::ops::linalg_generic_pointwise(
|
||||
@@ -228,14 +219,14 @@ Operation *mlir::edsc::ops::linalg_generic_pointwise(
|
||||
if (O.getType().isa<RankedTensorType>()) {
|
||||
auto fun = [&unaryOp](ArrayRef<BlockArgument> args) {
|
||||
assert(args.size() == 1 && "expected 1 block arguments");
|
||||
ValueHandle a(args[0]);
|
||||
Value a(args[0]);
|
||||
linalg_yield(unaryOp(a));
|
||||
};
|
||||
return makeGenericLinalgOp(iterTypes, {I}, {O}, fun);
|
||||
}
|
||||
auto fun = [&unaryOp](ArrayRef<BlockArgument> args) {
|
||||
assert(args.size() == 2 && "expected 2 block arguments");
|
||||
ValueHandle a(args[0]);
|
||||
Value a(args[0]);
|
||||
linalg_yield(unaryOp(a));
|
||||
};
|
||||
return makeGenericLinalgOp(iterTypes, {I}, {O}, fun);
|
||||
@@ -243,8 +234,7 @@ Operation *mlir::edsc::ops::linalg_generic_pointwise(
|
||||
|
||||
Operation *mlir::edsc::ops::linalg_generic_pointwise_tanh(StructuredIndexed I,
|
||||
StructuredIndexed O) {
|
||||
UnaryPointwiseOpBuilder unOp(
|
||||
[](ValueHandle a) -> Value { return std_tanh(a); });
|
||||
UnaryPointwiseOpBuilder unOp([](Value a) -> Value { return std_tanh(a); });
|
||||
return linalg_generic_pointwise(unOp, I, O);
|
||||
}
|
||||
|
||||
@@ -257,14 +247,14 @@ Operation *mlir::edsc::ops::linalg_generic_pointwise(
|
||||
if (O.getType().isa<RankedTensorType>()) {
|
||||
auto fun = [&binaryOp](ArrayRef<BlockArgument> args) {
|
||||
assert(args.size() == 2 && "expected 2 block arguments");
|
||||
ValueHandle a(args[0]), b(args[1]);
|
||||
Value a(args[0]), b(args[1]);
|
||||
linalg_yield(binaryOp(a, b));
|
||||
};
|
||||
return makeGenericLinalgOp(iterTypes, {I1, I2}, {O}, fun);
|
||||
}
|
||||
auto fun = [&binaryOp](ArrayRef<BlockArgument> args) {
|
||||
assert(args.size() == 3 && "expected 3 block arguments");
|
||||
ValueHandle a(args[0]), b(args[1]);
|
||||
Value a(args[0]), b(args[1]);
|
||||
linalg_yield(binaryOp(a, b));
|
||||
};
|
||||
return makeGenericLinalgOp(iterTypes, {I1, I2}, {O}, fun);
|
||||
@@ -275,23 +265,22 @@ Operation *mlir::edsc::ops::linalg_generic_pointwise_add(StructuredIndexed I1,
|
||||
StructuredIndexed O) {
|
||||
using edsc::op::operator+;
|
||||
BinaryPointwiseOpBuilder binOp(
|
||||
[](ValueHandle a, ValueHandle b) -> Value { return a + b; });
|
||||
[](Value a, Value b) -> Value { return a + b; });
|
||||
return linalg_generic_pointwise(binOp, I1, I2, O);
|
||||
}
|
||||
|
||||
Operation *mlir::edsc::ops::linalg_generic_pointwise_max(StructuredIndexed I1,
|
||||
StructuredIndexed I2,
|
||||
StructuredIndexed O) {
|
||||
BinaryPointwiseOpBuilder binOp([](ValueHandle a, ValueHandle b) -> Value {
|
||||
BinaryPointwiseOpBuilder binOp([](Value a, Value b) -> Value {
|
||||
using edsc::op::operator>;
|
||||
return std_select(a > b, a, b).getValue();
|
||||
return std_select(a > b, a, b);
|
||||
});
|
||||
return linalg_generic_pointwise(binOp, I1, I2, O);
|
||||
}
|
||||
|
||||
Operation *
|
||||
mlir::edsc::ops::linalg_generic_matmul(ValueHandle vA, ValueHandle vB,
|
||||
ValueHandle vC,
|
||||
mlir::edsc::ops::linalg_generic_matmul(Value vA, Value vB, Value vC,
|
||||
MatmulRegionBuilder regionBuilder) {
|
||||
// clang-format off
|
||||
AffineExpr m, n, k;
|
||||
@@ -306,8 +295,7 @@ mlir::edsc::ops::linalg_generic_matmul(ValueHandle vA, ValueHandle vB,
|
||||
}
|
||||
|
||||
Operation *
|
||||
mlir::edsc::ops::linalg_generic_matmul(ValueHandle vA, ValueHandle vB,
|
||||
RankedTensorType tC,
|
||||
mlir::edsc::ops::linalg_generic_matmul(Value vA, Value vB, RankedTensorType tC,
|
||||
MatmulRegionBuilder regionBuilder) {
|
||||
// clang-format off
|
||||
AffineExpr m, n, k;
|
||||
@@ -322,8 +310,8 @@ mlir::edsc::ops::linalg_generic_matmul(ValueHandle vA, ValueHandle vB,
|
||||
}
|
||||
|
||||
Operation *
|
||||
mlir::edsc::ops::linalg_generic_matmul(ValueHandle vA, ValueHandle vB,
|
||||
ValueHandle vC, RankedTensorType tD,
|
||||
mlir::edsc::ops::linalg_generic_matmul(Value vA, Value vB, Value vC,
|
||||
RankedTensorType tD,
|
||||
MatmulRegionBuilder regionBuilder) {
|
||||
// clang-format off
|
||||
AffineExpr m, n, k;
|
||||
@@ -337,9 +325,8 @@ mlir::edsc::ops::linalg_generic_matmul(ValueHandle vA, ValueHandle vB,
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
Operation *mlir::edsc::ops::linalg_generic_conv_nhwc(ValueHandle vI,
|
||||
ValueHandle vW,
|
||||
ValueHandle vO,
|
||||
Operation *mlir::edsc::ops::linalg_generic_conv_nhwc(Value vI, Value vW,
|
||||
Value vO,
|
||||
ArrayRef<int> strides,
|
||||
ArrayRef<int> dilations) {
|
||||
MLIRContext *ctx = ScopedContext::getContext();
|
||||
@@ -373,8 +360,8 @@ Operation *mlir::edsc::ops::linalg_generic_conv_nhwc(ValueHandle vI,
|
||||
}
|
||||
|
||||
Operation *mlir::edsc::ops::linalg_generic_dilated_conv_nhwc(
|
||||
ValueHandle vI, ValueHandle vW, ValueHandle vO, int depth_multiplier,
|
||||
ArrayRef<int> strides, ArrayRef<int> dilations) {
|
||||
Value vI, Value vW, Value vO, int depth_multiplier, ArrayRef<int> strides,
|
||||
ArrayRef<int> dilations) {
|
||||
MLIRContext *ctx = ScopedContext::getContext();
|
||||
// TODO(ntv) some template magic to make everything rank-polymorphic.
|
||||
assert((dilations.empty() || dilations.size() == 2) && "only 2-D conv atm");
|
||||
|
||||
@@ -35,7 +35,7 @@ using namespace mlir::edsc;
|
||||
using namespace mlir::edsc::intrinsics;
|
||||
using namespace mlir::linalg;
|
||||
|
||||
using folded_std_constant_index = folded::ValueBuilder<ConstantIndexOp>;
|
||||
using folded_std_constant_index = FoldedValueBuilder<ConstantIndexOp>;
|
||||
|
||||
using llvm::dbgs;
|
||||
|
||||
|
||||
@@ -29,17 +29,16 @@ using namespace mlir::edsc::intrinsics;
|
||||
using namespace mlir::linalg;
|
||||
|
||||
using edsc::op::operator+;
|
||||
using edsc::op::operator==;
|
||||
using mlir::edsc::intrinsics::detail::ValueHandleArray;
|
||||
|
||||
static SmallVector<ValueHandle, 8>
|
||||
makeCanonicalAffineApplies(OpBuilder &b, Location loc, AffineMap map,
|
||||
ArrayRef<Value> vals) {
|
||||
static SmallVector<Value, 8> makeCanonicalAffineApplies(OpBuilder &b,
|
||||
Location loc,
|
||||
AffineMap map,
|
||||
ArrayRef<Value> vals) {
|
||||
if (map.isEmpty())
|
||||
return {};
|
||||
assert(map.getNumSymbols() == 0);
|
||||
assert(map.getNumInputs() == vals.size());
|
||||
SmallVector<ValueHandle, 8> res;
|
||||
SmallVector<Value, 8> res;
|
||||
res.reserve(map.getNumResults());
|
||||
auto dims = map.getNumDims();
|
||||
for (auto e : map.getResults()) {
|
||||
@@ -80,10 +79,10 @@ SmallVector<Value, 4> emitLoopRanges(OpBuilder &b, Location loc, AffineMap map,
|
||||
}
|
||||
|
||||
template <typename OpType>
|
||||
static void inlineRegionAndEmitStdStore(OpType op,
|
||||
ArrayRef<Value> indexedValues,
|
||||
ArrayRef<ValueHandleArray> indexing,
|
||||
ArrayRef<Value> outputBuffers) {
|
||||
static void
|
||||
inlineRegionAndEmitStdStore(OpType op, ArrayRef<Value> indexedValues,
|
||||
ArrayRef<SmallVector<Value, 8>> indexing,
|
||||
ArrayRef<Value> outputBuffers) {
|
||||
auto &b = ScopedContext::getBuilder();
|
||||
auto &block = op.region().front();
|
||||
BlockAndValueMapping map;
|
||||
@@ -99,25 +98,27 @@ static void inlineRegionAndEmitStdStore(OpType op,
|
||||
"expected an yield op in the end of the region");
|
||||
for (unsigned i = 0, e = terminator.getNumOperands(); i < e; ++i) {
|
||||
std_store(map.lookupOrDefault(terminator.getOperand(i)), outputBuffers[i],
|
||||
indexing[i]);
|
||||
ArrayRef<Value>{indexing[i].begin(), indexing[i].end()});
|
||||
}
|
||||
}
|
||||
|
||||
// Returns a pair that contains input indices and output indices of a
|
||||
// SingleInputPoolingOp `op`.
|
||||
struct InputAndOutputIndices {
|
||||
SmallVector<Value, 8> inputs;
|
||||
SmallVector<Value, 8> outputs;
|
||||
};
|
||||
template <typename SingleInputPoolingOp>
|
||||
static std::pair<SmallVector<ValueHandle, 8>, SmallVector<ValueHandle, 8>>
|
||||
getInputAndOutputIndices(ArrayRef<Value> allIvs, SingleInputPoolingOp op) {
|
||||
static InputAndOutputIndices getInputAndOutputIndices(ArrayRef<Value> allIvs,
|
||||
SingleInputPoolingOp op) {
|
||||
auto &b = ScopedContext::getBuilder();
|
||||
auto loc = ScopedContext::getLocation();
|
||||
auto mapsRange = op.indexing_maps().template getAsRange<AffineMapAttr>();
|
||||
auto maps = llvm::to_vector<8>(
|
||||
llvm::map_range(mapsRange, [](AffineMapAttr a) { return a.getValue(); }));
|
||||
SmallVector<ValueHandle, 8> iIdx(
|
||||
makeCanonicalAffineApplies(b, loc, maps[0], allIvs));
|
||||
SmallVector<ValueHandle, 8> oIdx(
|
||||
makeCanonicalAffineApplies(b, loc, maps[2], allIvs));
|
||||
return {iIdx, oIdx};
|
||||
return InputAndOutputIndices{
|
||||
makeCanonicalAffineApplies(b, loc, maps[0], allIvs),
|
||||
makeCanonicalAffineApplies(b, loc, maps[2], allIvs)};
|
||||
}
|
||||
|
||||
namespace {
|
||||
@@ -150,8 +151,8 @@ public:
|
||||
permuteIvs(allIvs.take_front(nPar), copyOp.inputPermutation());
|
||||
auto outputIvs =
|
||||
permuteIvs(allIvs.take_front(nPar), copyOp.outputPermutation());
|
||||
SmallVector<ValueHandle, 8> iivs(inputIvs.begin(), inputIvs.end());
|
||||
SmallVector<ValueHandle, 8> oivs(outputIvs.begin(), outputIvs.end());
|
||||
SmallVector<Value, 8> iivs(inputIvs.begin(), inputIvs.end());
|
||||
SmallVector<Value, 8> oivs(outputIvs.begin(), outputIvs.end());
|
||||
IndexedValueType O(copyOp.getOutputBuffer(0)), I(copyOp.getInput(0));
|
||||
// Emit the proper scalar assignment, whether we are dealing with a 0-D or
|
||||
// an n-D loop nest; with or without permutations.
|
||||
@@ -170,13 +171,11 @@ public:
|
||||
"expected linalg op with buffer semantics");
|
||||
auto nPar = fillOp.getNumParallelLoops();
|
||||
assert(nPar == allIvs.size());
|
||||
auto ivs =
|
||||
SmallVector<ValueHandle, 4>(allIvs.begin(), allIvs.begin() + nPar);
|
||||
auto ivs = SmallVector<Value, 4>(allIvs.begin(), allIvs.begin() + nPar);
|
||||
IndexedValueType O(fillOp.getOutputBuffer(0));
|
||||
// Emit the proper scalar assignment, whether we are dealing with a 0-D or
|
||||
// an n-D loop nest; with or without permutations.
|
||||
nPar > 0 ? O(ivs) = ValueHandle(fillOp.value())
|
||||
: O() = ValueHandle(fillOp.value());
|
||||
nPar > 0 ? O(ivs) = fillOp.value() : O() = fillOp.value();
|
||||
}
|
||||
};
|
||||
|
||||
@@ -187,7 +186,7 @@ public:
|
||||
assert(dotOp.hasBufferSemantics() &&
|
||||
"expected linalg op with buffer semantics");
|
||||
assert(allIvs.size() == 1);
|
||||
ValueHandle r_i(allIvs[0]);
|
||||
Value r_i(allIvs[0]);
|
||||
IndexedValueType A(dotOp.getInput(0)), B(dotOp.getInput(1)),
|
||||
C(dotOp.getOutputBuffer(0));
|
||||
// Emit scalar form.
|
||||
@@ -203,7 +202,7 @@ public:
|
||||
assert(matvecOp.hasBufferSemantics() &&
|
||||
"expected linalg op with buffer semantics");
|
||||
assert(allIvs.size() == 2);
|
||||
ValueHandle i(allIvs[0]), r_j(allIvs[1]);
|
||||
Value i(allIvs[0]), r_j(allIvs[1]);
|
||||
IndexedValueType A(matvecOp.getInput(0)), B(matvecOp.getInput(1)),
|
||||
C(matvecOp.getOutputBuffer(0));
|
||||
// Emit scalar form.
|
||||
@@ -219,7 +218,7 @@ public:
|
||||
assert(matmulOp.hasBufferSemantics() &&
|
||||
"expected linalg op with buffer semantics");
|
||||
assert(allIvs.size() == 3);
|
||||
ValueHandle i(allIvs[0]), j(allIvs[1]), r_k(allIvs[2]);
|
||||
Value i(allIvs[0]), j(allIvs[1]), r_k(allIvs[2]);
|
||||
IndexedValueType A(matmulOp.getInput(0)), B(matmulOp.getInput(1)),
|
||||
C(matmulOp.getOutputBuffer(0));
|
||||
// Emit scalar form.
|
||||
@@ -232,16 +231,16 @@ class LinalgScopedEmitter<IndexedValueType, ConvOp> {
|
||||
public:
|
||||
/// Returns the input value of convOp. If the indices in `imIdx` is out of
|
||||
/// boundary, returns 0 instead.
|
||||
static ValueHandle getConvOpInput(ConvOp convOp, IndexedValueType im,
|
||||
ArrayRef<ValueHandle> imIdx) {
|
||||
static Value getConvOpInput(ConvOp convOp, IndexedValueType im,
|
||||
MutableArrayRef<Value> imIdx) {
|
||||
// TODO(ntv): add a level of indirection to linalg.generic.
|
||||
if (!convOp.padding())
|
||||
return im(imIdx);
|
||||
|
||||
auto *context = ScopedContext::getContext();
|
||||
ValueHandle zeroIndex = std_constant_index(0);
|
||||
SmallVector<ValueHandle, 8> conds;
|
||||
SmallVector<ValueHandle, 8> clampedImIdx;
|
||||
Value zeroIndex = std_constant_index(0);
|
||||
SmallVector<Value, 8> conds;
|
||||
SmallVector<Value, 8> clampedImIdx;
|
||||
for (auto iter : llvm::enumerate(imIdx)) {
|
||||
int idx = iter.index();
|
||||
auto dim = iter.value();
|
||||
@@ -254,12 +253,12 @@ public:
|
||||
using edsc::op::operator<;
|
||||
using edsc::op::operator>=;
|
||||
using edsc::op::operator||;
|
||||
ValueHandle leftOutOfBound = dim < zeroIndex;
|
||||
Value leftOutOfBound = dim < zeroIndex;
|
||||
if (conds.empty())
|
||||
conds.push_back(leftOutOfBound);
|
||||
else
|
||||
conds.push_back(conds.back() || leftOutOfBound);
|
||||
ValueHandle rightBound = std_dim(convOp.input(), idx);
|
||||
Value rightBound = std_dim(convOp.input(), idx);
|
||||
conds.push_back(conds.back() || (dim >= rightBound));
|
||||
|
||||
// When padding is involved, the indices will only be shifted to negative,
|
||||
@@ -274,10 +273,10 @@ public:
|
||||
|
||||
auto b = ScopedContext::getBuilder();
|
||||
Type type = convOp.input().getType().cast<MemRefType>().getElementType();
|
||||
ValueHandle zero = std_constant(type, b.getZeroAttr(type));
|
||||
ValueHandle readInput = im(clampedImIdx);
|
||||
Value zero = std_constant(type, b.getZeroAttr(type));
|
||||
Value readInput = im(clampedImIdx);
|
||||
return conds.empty() ? readInput
|
||||
: std_select(conds.back(), zero, readInput);
|
||||
: (Value)std_select(conds.back(), zero, readInput);
|
||||
}
|
||||
|
||||
static void emitScalarImplementation(ArrayRef<Value> allIvs, ConvOp convOp) {
|
||||
@@ -288,16 +287,16 @@ public:
|
||||
auto mapsRange = convOp.indexing_maps().getAsRange<AffineMapAttr>();
|
||||
auto maps = llvm::to_vector<8>(llvm::map_range(
|
||||
mapsRange, [](AffineMapAttr a) { return a.getValue(); }));
|
||||
SmallVector<ValueHandle, 8> fIdx(
|
||||
SmallVector<Value, 8> fIdx(
|
||||
makeCanonicalAffineApplies(b, loc, maps[0], allIvs));
|
||||
SmallVector<ValueHandle, 8> imIdx(
|
||||
SmallVector<Value, 8> imIdx(
|
||||
makeCanonicalAffineApplies(b, loc, maps[1], allIvs));
|
||||
SmallVector<ValueHandle, 8> oIdx(
|
||||
SmallVector<Value, 8> oIdx(
|
||||
makeCanonicalAffineApplies(b, loc, maps[2], allIvs));
|
||||
IndexedValueType F(convOp.filter()), I(convOp.input()), O(convOp.output());
|
||||
|
||||
// Emit scalar form.
|
||||
ValueHandle paddedInput = getConvOpInput(convOp, I, imIdx);
|
||||
Value paddedInput = getConvOpInput(convOp, I, imIdx);
|
||||
O(oIdx) += F(fIdx) * paddedInput;
|
||||
}
|
||||
};
|
||||
@@ -308,15 +307,12 @@ public:
|
||||
static void emitScalarImplementation(ArrayRef<Value> allIvs,
|
||||
PoolingMaxOp op) {
|
||||
auto indices = getInputAndOutputIndices(allIvs, op);
|
||||
ValueHandleArray iIdx(indices.first);
|
||||
ValueHandleArray oIdx(indices.second);
|
||||
|
||||
// Emit scalar form.
|
||||
ValueHandle lhs = std_load(op.output(), oIdx);
|
||||
ValueHandle rhs = std_load(op.input(), iIdx);
|
||||
Value lhs = std_load(op.output(), indices.outputs);
|
||||
Value rhs = std_load(op.input(), indices.inputs);
|
||||
using edsc::op::operator>;
|
||||
ValueHandle maxValue = std_select(lhs > rhs, lhs, rhs);
|
||||
std_store(maxValue, op.output(), oIdx);
|
||||
Value maxValue = std_select(lhs > rhs, lhs, rhs);
|
||||
std_store(maxValue, op.output(), indices.outputs);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -326,15 +322,12 @@ public:
|
||||
static void emitScalarImplementation(ArrayRef<Value> allIvs,
|
||||
PoolingMinOp op) {
|
||||
auto indices = getInputAndOutputIndices(allIvs, op);
|
||||
ValueHandleArray iIdx(indices.first);
|
||||
ValueHandleArray oIdx(indices.second);
|
||||
|
||||
// Emit scalar form.
|
||||
ValueHandle lhs = std_load(op.output(), oIdx);
|
||||
ValueHandle rhs = std_load(op.input(), iIdx);
|
||||
Value lhs = std_load(op.output(), indices.outputs);
|
||||
Value rhs = std_load(op.input(), indices.inputs);
|
||||
using edsc::op::operator<;
|
||||
ValueHandle minValue = std_select(lhs < rhs, lhs, rhs);
|
||||
std_store(minValue, op.output(), oIdx);
|
||||
Value minValue = std_select(lhs < rhs, lhs, rhs);
|
||||
std_store(minValue, op.output(), indices.outputs);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -344,12 +337,10 @@ public:
|
||||
static void emitScalarImplementation(ArrayRef<Value> allIvs,
|
||||
PoolingSumOp op) {
|
||||
auto indices = getInputAndOutputIndices(allIvs, op);
|
||||
SmallVector<ValueHandle, 8> iIdx = indices.first;
|
||||
SmallVector<ValueHandle, 8> oIdx = indices.second;
|
||||
IndexedValueType input(op.input()), output(op.output());
|
||||
|
||||
// Emit scalar form.
|
||||
output(oIdx) += input(iIdx);
|
||||
output(indices.outputs) += input(indices.inputs);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -392,15 +383,14 @@ public:
|
||||
"expected linalg op with buffer semantics");
|
||||
auto b = ScopedContext::getBuilder();
|
||||
auto loc = ScopedContext::getLocation();
|
||||
using edsc::intrinsics::detail::ValueHandleArray;
|
||||
unsigned nInputs = genericOp.getNumInputs();
|
||||
unsigned nOutputs = genericOp.getNumOutputs();
|
||||
SmallVector<Value, 4> indexedValues(nInputs + nOutputs);
|
||||
|
||||
// 1.a. Emit std_load from input views.
|
||||
for (unsigned i = 0; i < nInputs; ++i) {
|
||||
ValueHandleArray indexing(makeCanonicalAffineApplies(
|
||||
b, loc, genericOp.getInputIndexingMap(i), allIvs));
|
||||
auto indexing = makeCanonicalAffineApplies(
|
||||
b, loc, genericOp.getInputIndexingMap(i), allIvs);
|
||||
indexedValues[i] = std_load(genericOp.getInput(i), indexing);
|
||||
}
|
||||
|
||||
@@ -409,18 +399,18 @@ public:
|
||||
// region has no uses.
|
||||
for (unsigned i = 0; i < nOutputs; ++i) {
|
||||
Value output = genericOp.getOutputBuffer(i);
|
||||
ValueHandleArray indexing(makeCanonicalAffineApplies(
|
||||
b, loc, genericOp.getOutputIndexingMap(i), allIvs));
|
||||
auto indexing = makeCanonicalAffineApplies(
|
||||
b, loc, genericOp.getOutputIndexingMap(i), allIvs);
|
||||
indexedValues[nInputs + i] = std_load(output, indexing);
|
||||
}
|
||||
|
||||
// TODO(ntv): When a region inliner exists, use it.
|
||||
// 2. Inline region, currently only works for a single basic block.
|
||||
// 3. Emit std_store.
|
||||
SmallVector<ValueHandleArray, 8> indexing;
|
||||
SmallVector<SmallVector<Value, 8>, 8> indexing;
|
||||
SmallVector<Value, 8> outputBuffers;
|
||||
for (unsigned i = 0; i < nOutputs; ++i) {
|
||||
indexing.emplace_back(makeCanonicalAffineApplies(
|
||||
indexing.push_back(makeCanonicalAffineApplies(
|
||||
b, loc, genericOp.getOutputIndexingMap(i), allIvs));
|
||||
outputBuffers.push_back(genericOp.getOutputBuffer(i));
|
||||
}
|
||||
@@ -468,7 +458,6 @@ public:
|
||||
"expected linalg op with buffer semantics");
|
||||
auto b = ScopedContext::getBuilder();
|
||||
auto loc = ScopedContext::getLocation();
|
||||
using edsc::intrinsics::detail::ValueHandleArray;
|
||||
unsigned nInputs = indexedGenericOp.getNumInputs();
|
||||
unsigned nOutputs = indexedGenericOp.getNumOutputs();
|
||||
unsigned nLoops = allIvs.size();
|
||||
@@ -481,26 +470,26 @@ public:
|
||||
// 1.a. Emit std_load from input views.
|
||||
for (unsigned i = 0; i < nInputs; ++i) {
|
||||
Value input = indexedGenericOp.getInput(i);
|
||||
ValueHandleArray indexing(makeCanonicalAffineApplies(
|
||||
b, loc, indexedGenericOp.getInputIndexingMap(i), allIvs));
|
||||
auto indexing = makeCanonicalAffineApplies(
|
||||
b, loc, indexedGenericOp.getInputIndexingMap(i), allIvs);
|
||||
indexedValues[nLoops + i] = std_load(input, indexing);
|
||||
}
|
||||
|
||||
// 1.b. Emit std_load from output views.
|
||||
for (unsigned i = 0; i < nOutputs; ++i) {
|
||||
Value output = indexedGenericOp.getOutputBuffer(i);
|
||||
ValueHandleArray indexing(makeCanonicalAffineApplies(
|
||||
b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs));
|
||||
auto indexing = makeCanonicalAffineApplies(
|
||||
b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs);
|
||||
indexedValues[nLoops + nInputs + i] = std_load(output, indexing);
|
||||
}
|
||||
|
||||
// TODO(ntv): When a region inliner exists, use it.
|
||||
// 2. Inline region, currently only works for a single basic block.
|
||||
// 3. Emit std_store.
|
||||
SmallVector<ValueHandleArray, 8> indexing;
|
||||
SmallVector<SmallVector<Value, 8>, 8> indexing;
|
||||
SmallVector<Value, 8> outputBuffers;
|
||||
for (unsigned i = 0; i < nOutputs; ++i) {
|
||||
indexing.emplace_back(makeCanonicalAffineApplies(
|
||||
indexing.push_back(makeCanonicalAffineApplies(
|
||||
b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs));
|
||||
outputBuffers.push_back(indexedGenericOp.getOutputBuffer(i));
|
||||
}
|
||||
@@ -533,11 +522,8 @@ public:
|
||||
typename std::conditional<std::is_same<LoopTy, AffineForOp>::value,
|
||||
AffineIndexedValue, StdIndexedValue>::type;
|
||||
static void doit(ConcreteOpTy linalgOp, ArrayRef<Value> loopRanges,
|
||||
MutableArrayRef<ValueHandle> allIvs) {
|
||||
SmallVector<ValueHandle *, 4> allPIvs =
|
||||
makeHandlePointers(MutableArrayRef<ValueHandle>(allIvs));
|
||||
|
||||
GenericLoopNestRangeBuilder<LoopTy>(allPIvs, loopRanges)([&] {
|
||||
MutableArrayRef<Value> allIvs) {
|
||||
GenericLoopNestRangeBuilder<LoopTy>(allIvs, loopRanges)([&] {
|
||||
SmallVector<Value, 4> allIvValues(allIvs.begin(), allIvs.end());
|
||||
LinalgScopedEmitter<IndexedValueTy,
|
||||
ConcreteOpTy>::emitScalarImplementation(allIvValues,
|
||||
@@ -555,7 +541,7 @@ public:
|
||||
using IndexedValueTy = StdIndexedValue;
|
||||
|
||||
static void doit(ConcreteOpTy linalgOp, ArrayRef<Value> loopRanges,
|
||||
MutableArrayRef<ValueHandle> allIvs) {
|
||||
MutableArrayRef<Value> allIvs) {
|
||||
// Only generate loop.parallel for outer consecutive "parallel"
|
||||
// iterator_types.
|
||||
// TODO(ravishankarm): Generate loop.parallel for all "parallel" iterator
|
||||
@@ -575,24 +561,18 @@ public:
|
||||
// If there are no outer parallel loops, then number of loop ops is same as
|
||||
// the number of loops, and they are all loop.for ops.
|
||||
auto nLoopOps = (nOuterPar ? nLoops - nOuterPar + 1 : nLoops);
|
||||
SmallVector<ValueHandle *, 4> allPIvs =
|
||||
makeHandlePointers(MutableArrayRef<ValueHandle>(allIvs));
|
||||
|
||||
SmallVector<OperationHandle, 4> allLoops(nLoopOps, OperationHandle());
|
||||
SmallVector<OperationHandle *, 4> allPLoops;
|
||||
allPLoops.reserve(allLoops.size());
|
||||
for (OperationHandle &loop : allLoops)
|
||||
allPLoops.push_back(&loop);
|
||||
|
||||
ArrayRef<ValueHandle *> allPIvsRef(allPIvs);
|
||||
ArrayRef<OperationHandle *> allPLoopsRef(allPLoops);
|
||||
|
||||
if (nOuterPar) {
|
||||
GenericLoopNestRangeBuilder<loop::ParallelOp>(
|
||||
allPIvsRef.take_front(nOuterPar),
|
||||
loopRanges.take_front(nOuterPar))([&] {
|
||||
allIvs.take_front(nOuterPar), loopRanges.take_front(nOuterPar))([&] {
|
||||
GenericLoopNestRangeBuilder<loop::ForOp>(
|
||||
allPIvsRef.drop_front(nOuterPar),
|
||||
allIvs.drop_front(nOuterPar),
|
||||
loopRanges.drop_front(nOuterPar))([&] {
|
||||
SmallVector<Value, 4> allIvValues(allIvs.begin(), allIvs.end());
|
||||
LinalgScopedEmitter<StdIndexedValue, ConcreteOpTy>::
|
||||
@@ -602,7 +582,7 @@ public:
|
||||
} else {
|
||||
// If there are no parallel loops then fallback to generating all loop.for
|
||||
// operations.
|
||||
GenericLoopNestRangeBuilder<loop::ForOp>(allPIvsRef, loopRanges)([&] {
|
||||
GenericLoopNestRangeBuilder<loop::ForOp>(allIvs, loopRanges)([&] {
|
||||
SmallVector<Value, 4> allIvValues(allIvs.begin(), allIvs.end());
|
||||
LinalgScopedEmitter<StdIndexedValue,
|
||||
ConcreteOpTy>::emitScalarImplementation(allIvValues,
|
||||
@@ -645,8 +625,7 @@ LinalgOpToLoopsImpl<LoopTy, ConcreteOpTy>::doit(Operation *op,
|
||||
return LinalgLoops();
|
||||
}
|
||||
|
||||
SmallVector<ValueHandle, 4> allIvs(nLoops,
|
||||
ValueHandle(rewriter.getIndexType()));
|
||||
SmallVector<Value, 4> allIvs(nLoops);
|
||||
auto loopRanges =
|
||||
emitLoopRanges(scope.getBuilder(), scope.getLocation(), invertedMap,
|
||||
getViewSizes(rewriter, linalgOp));
|
||||
@@ -655,12 +634,12 @@ LinalgOpToLoopsImpl<LoopTy, ConcreteOpTy>::doit(Operation *op,
|
||||
// Number of loop ops might be different from the number of ivs since some
|
||||
// loops like affine.parallel and loop.parallel have multiple ivs.
|
||||
llvm::SetVector<Operation *> loopSet;
|
||||
for (ValueHandle &iv : allIvs) {
|
||||
if (!iv.hasValue())
|
||||
for (Value iv : allIvs) {
|
||||
if (!iv)
|
||||
return {};
|
||||
// The induction variable is a block argument of the entry block of the
|
||||
// loop operation.
|
||||
BlockArgument ivVal = iv.getValue().dyn_cast<BlockArgument>();
|
||||
BlockArgument ivVal = iv.dyn_cast<BlockArgument>();
|
||||
if (!ivVal)
|
||||
return {};
|
||||
loopSet.insert(ivVal.getOwner()->getParentOp());
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
#include "mlir/Dialect/Linalg/Utils/Utils.h"
|
||||
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
|
||||
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
|
||||
#include "mlir/Dialect/Vector/EDSC/Intrinsics.h"
|
||||
#include "mlir/Dialect/Vector/VectorOps.h"
|
||||
#include "mlir/IR/AffineExpr.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
@@ -219,10 +220,6 @@ LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
|
||||
|
||||
SmallVector<Value, 0> mlir::linalg::vectorizeLinalgOp(PatternRewriter &rewriter,
|
||||
Operation *op) {
|
||||
using vector_contract = edsc::intrinsics::ValueBuilder<vector::ContractionOp>;
|
||||
using vector_broadcast = edsc::intrinsics::ValueBuilder<vector::BroadcastOp>;
|
||||
using vector_type_cast = edsc::intrinsics::ValueBuilder<vector::TypeCastOp>;
|
||||
|
||||
assert(succeeded(vectorizeLinalgOpPrecondition(op)) &&
|
||||
"DRR failure case must be a precondition");
|
||||
auto linalgOp = cast<linalg::LinalgOp>(op);
|
||||
@@ -242,8 +239,8 @@ SmallVector<Value, 0> mlir::linalg::vectorizeLinalgOp(PatternRewriter &rewriter,
|
||||
"]: Rewrite linalg.fill as vector.broadcast: "
|
||||
<< *op << ":\n");
|
||||
auto dstMemrefVec = vector_type_cast(fillOp.getOutputBuffer(0));
|
||||
auto dstVec = std_load(dstMemrefVec);
|
||||
auto resVec = vector_broadcast(dstVec, fillOp.value());
|
||||
Value dstVec = std_load(dstMemrefVec);
|
||||
auto resVec = vector_broadcast(dstVec.getType(), fillOp.value());
|
||||
std_store(resVec, dstMemrefVec);
|
||||
} else {
|
||||
// Vectorize other ops as vector contraction (currently only matmul).
|
||||
|
||||
@@ -36,11 +36,11 @@ using namespace mlir::loop;
|
||||
|
||||
using llvm::SetVector;
|
||||
|
||||
using folded_affine_min = folded::ValueBuilder<AffineMinOp>;
|
||||
using folded_linalg_range = folded::ValueBuilder<linalg::RangeOp>;
|
||||
using folded_std_dim = folded::ValueBuilder<DimOp>;
|
||||
using folded_std_subview = folded::ValueBuilder<SubViewOp>;
|
||||
using folded_std_view = folded::ValueBuilder<ViewOp>;
|
||||
using folded_affine_min = FoldedValueBuilder<AffineMinOp>;
|
||||
using folded_linalg_range = FoldedValueBuilder<linalg::RangeOp>;
|
||||
using folded_std_dim = FoldedValueBuilder<DimOp>;
|
||||
using folded_std_subview = FoldedValueBuilder<SubViewOp>;
|
||||
using folded_std_view = FoldedValueBuilder<ViewOp>;
|
||||
|
||||
#define DEBUG_TYPE "linalg-promotion"
|
||||
|
||||
@@ -74,8 +74,8 @@ static Value allocBuffer(Type elementType, Value size, bool dynamicBuffers,
|
||||
if (!dynamicBuffers)
|
||||
if (auto cst = dyn_cast_or_null<ConstantIndexOp>(size.getDefiningOp()))
|
||||
return std_alloc(
|
||||
MemRefType::get(width * cst.getValue(), IntegerType::get(8, ctx)), {},
|
||||
alignment_attr);
|
||||
MemRefType::get(width * cst.getValue(), IntegerType::get(8, ctx)),
|
||||
ValueRange{}, alignment_attr);
|
||||
Value mul =
|
||||
folded_std_muli(folder, folded_std_constant_index(folder, width), size);
|
||||
return std_alloc(MemRefType::get(-1, IntegerType::get(8, ctx)), mul,
|
||||
@@ -118,7 +118,7 @@ static PromotionInfo promoteFullTileBuffer(OpBuilder &b, Location loc,
|
||||
auto rangeValue = en.value();
|
||||
// Try to extract a tight constant
|
||||
Value size = extractSmallestConstantBoundingSize(b, loc, rangeValue.size);
|
||||
allocSize = folded_std_muli(folder, allocSize, size).getValue();
|
||||
allocSize = folded_std_muli(folder, allocSize, size);
|
||||
fullSizes.push_back(size);
|
||||
partialSizes.push_back(folded_std_dim(folder, subView, rank));
|
||||
}
|
||||
|
||||
@@ -32,7 +32,7 @@ using namespace mlir::edsc::intrinsics;
|
||||
using namespace mlir::linalg;
|
||||
using namespace mlir::loop;
|
||||
|
||||
using folded_affine_min = folded::ValueBuilder<AffineMinOp>;
|
||||
using folded_affine_min = FoldedValueBuilder<AffineMinOp>;
|
||||
|
||||
#define DEBUG_TYPE "linalg-tiling"
|
||||
|
||||
@@ -163,7 +163,7 @@ struct TileCheck : public AffineExprVisitor<TileCheck> {
|
||||
// TODO(pifon, ntv): Investigate whether mixing implicit and explicit indices
|
||||
// does not lead to losing information.
|
||||
static void transformIndexedGenericOpIndices(
|
||||
OpBuilder &b, LinalgOp op, ArrayRef<ValueHandle *> pivs,
|
||||
OpBuilder &b, LinalgOp op, SmallVectorImpl<Value> &ivs,
|
||||
const LoopIndexToRangeIndexMap &loopIndexToRangeIndex) {
|
||||
assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics");
|
||||
auto indexedGenericOp = dyn_cast<IndexedGenericOp>(op.getOperation());
|
||||
@@ -193,7 +193,7 @@ static void transformIndexedGenericOpIndices(
|
||||
// Offset the index argument `i` by the value of the corresponding induction
|
||||
// variable and replace all uses of the previous value.
|
||||
Value newIndex = b.create<AddIOp>(indexedGenericOp.getLoc(), oldIndex,
|
||||
pivs[rangeIndex->second]->getValue());
|
||||
ivs[rangeIndex->second]);
|
||||
for (auto &use : oldIndex.getUses()) {
|
||||
if (use.getOwner() == newIndex.getDefiningOp())
|
||||
continue;
|
||||
@@ -376,15 +376,14 @@ Optional<TiledLinalgOp> static tileLinalgOpImpl(OpBuilder &b, LinalgOp op,
|
||||
|
||||
// 3. Create the tiled loops.
|
||||
LinalgOp res = op;
|
||||
auto ivs = ValueHandle::makeIndexHandles(loopRanges.size());
|
||||
auto pivs = makeHandlePointers(MutableArrayRef<ValueHandle>(ivs));
|
||||
SmallVector<Value, 4> ivs(loopRanges.size());
|
||||
// Convert SubViewOp::Range to linalg_range.
|
||||
SmallVector<Value, 4> linalgRanges;
|
||||
for (auto &range : loopRanges) {
|
||||
linalgRanges.push_back(
|
||||
linalg_range(range.offset, range.size, range.stride));
|
||||
}
|
||||
GenericLoopNestRangeBuilder<LoopTy>(pivs, linalgRanges)([&] {
|
||||
GenericLoopNestRangeBuilder<LoopTy>(ivs, linalgRanges)([&] {
|
||||
auto b = ScopedContext::getBuilder();
|
||||
auto loc = ScopedContext::getLocation();
|
||||
SmallVector<Value, 4> ivValues(ivs.begin(), ivs.end());
|
||||
@@ -405,7 +404,7 @@ Optional<TiledLinalgOp> static tileLinalgOpImpl(OpBuilder &b, LinalgOp op,
|
||||
});
|
||||
|
||||
// 4. Transforms index arguments of `linalg.generic` w.r.t. to the tiling.
|
||||
transformIndexedGenericOpIndices(b, res, pivs, loopIndexToRangeIndex);
|
||||
transformIndexedGenericOpIndices(b, res, ivs, loopIndexToRangeIndex);
|
||||
|
||||
// 5. Gather the newly created loops and return them with the new op.
|
||||
SmallVector<Operation *, 8> loops;
|
||||
|
||||
@@ -14,8 +14,8 @@ using namespace mlir;
|
||||
using namespace mlir::edsc;
|
||||
|
||||
mlir::edsc::ParallelLoopNestBuilder::ParallelLoopNestBuilder(
|
||||
ArrayRef<ValueHandle *> ivs, ArrayRef<ValueHandle> lbs,
|
||||
ArrayRef<ValueHandle> ubs, ArrayRef<ValueHandle> steps) {
|
||||
MutableArrayRef<Value> ivs, ArrayRef<Value> lbs, ArrayRef<Value> ubs,
|
||||
ArrayRef<Value> steps) {
|
||||
assert(ivs.size() == lbs.size() && "Mismatch in number of arguments");
|
||||
assert(ivs.size() == ubs.size() && "Mismatch in number of arguments");
|
||||
assert(ivs.size() == steps.size() && "Mismatch in number of arguments");
|
||||
@@ -36,29 +36,34 @@ void mlir::edsc::ParallelLoopNestBuilder::operator()(
|
||||
(*lit)();
|
||||
}
|
||||
|
||||
mlir::edsc::LoopNestBuilder::LoopNestBuilder(ArrayRef<ValueHandle *> ivs,
|
||||
ArrayRef<ValueHandle> lbs,
|
||||
ArrayRef<ValueHandle> ubs,
|
||||
ArrayRef<ValueHandle> steps) {
|
||||
mlir::edsc::LoopNestBuilder::LoopNestBuilder(MutableArrayRef<Value> ivs,
|
||||
ArrayRef<Value> lbs,
|
||||
ArrayRef<Value> ubs,
|
||||
ArrayRef<Value> steps) {
|
||||
assert(ivs.size() == lbs.size() && "expected size of ivs and lbs to match");
|
||||
assert(ivs.size() == ubs.size() && "expected size of ivs and ubs to match");
|
||||
assert(ivs.size() == steps.size() &&
|
||||
"expected size of ivs and steps to match");
|
||||
loops.reserve(ivs.size());
|
||||
for (auto it : llvm::zip(ivs, lbs, ubs, steps))
|
||||
loops.emplace_back(makeLoopBuilder(std::get<0>(it), std::get<1>(it),
|
||||
loops.emplace_back(makeLoopBuilder(&std::get<0>(it), std::get<1>(it),
|
||||
std::get<2>(it), std::get<3>(it)));
|
||||
assert(loops.size() == ivs.size() && "Mismatch loops vs ivs size");
|
||||
}
|
||||
|
||||
mlir::edsc::LoopNestBuilder::LoopNestBuilder(
|
||||
ValueHandle *iv, ValueHandle lb, ValueHandle ub, ValueHandle step,
|
||||
ArrayRef<ValueHandle *> iter_args_handles,
|
||||
ValueRange iter_args_init_values) {
|
||||
assert(iter_args_init_values.size() == iter_args_handles.size() &&
|
||||
Value *iv, Value lb, Value ub, Value step,
|
||||
MutableArrayRef<Value> iterArgsHandles, ValueRange iterArgsInitValues) {
|
||||
assert(iterArgsInitValues.size() == iterArgsHandles.size() &&
|
||||
"expected size of arguments and argument_handles to match");
|
||||
loops.emplace_back(makeLoopBuilder(iv, lb, ub, step, iter_args_handles,
|
||||
iter_args_init_values));
|
||||
loops.emplace_back(
|
||||
makeLoopBuilder(iv, lb, ub, step, iterArgsHandles, iterArgsInitValues));
|
||||
}
|
||||
|
||||
mlir::edsc::LoopNestBuilder::LoopNestBuilder(Value *iv, Value lb, Value ub,
|
||||
Value step) {
|
||||
SmallVector<Value, 0> noArgs;
|
||||
loops.emplace_back(makeLoopBuilder(iv, lb, ub, step, noArgs, {}));
|
||||
}
|
||||
|
||||
Operation::result_range
|
||||
@@ -73,10 +78,10 @@ mlir::edsc::LoopNestBuilder::LoopNestBuilder::operator()(
|
||||
return loops[0].getOp()->getResults();
|
||||
}
|
||||
|
||||
LoopBuilder mlir::edsc::makeParallelLoopBuilder(ArrayRef<ValueHandle *> ivs,
|
||||
ArrayRef<ValueHandle> lbHandles,
|
||||
ArrayRef<ValueHandle> ubHandles,
|
||||
ArrayRef<ValueHandle> steps) {
|
||||
LoopBuilder mlir::edsc::makeParallelLoopBuilder(MutableArrayRef<Value> ivs,
|
||||
ArrayRef<Value> lbHandles,
|
||||
ArrayRef<Value> ubHandles,
|
||||
ArrayRef<Value> steps) {
|
||||
LoopBuilder result;
|
||||
auto opHandle = OperationHandle::create<loop::ParallelOp>(
|
||||
SmallVector<Value, 4>(lbHandles.begin(), lbHandles.end()),
|
||||
@@ -86,24 +91,22 @@ LoopBuilder mlir::edsc::makeParallelLoopBuilder(ArrayRef<ValueHandle *> ivs,
|
||||
loop::ParallelOp parallelOp =
|
||||
cast<loop::ParallelOp>(*opHandle.getOperation());
|
||||
for (size_t i = 0, e = ivs.size(); i < e; ++i)
|
||||
*ivs[i] = ValueHandle(parallelOp.getBody()->getArgument(i));
|
||||
ivs[i] = parallelOp.getBody()->getArgument(i);
|
||||
result.enter(parallelOp.getBody(), /*prev=*/1);
|
||||
return result;
|
||||
}
|
||||
|
||||
mlir::edsc::LoopBuilder
|
||||
mlir::edsc::makeLoopBuilder(ValueHandle *iv, ValueHandle lbHandle,
|
||||
ValueHandle ubHandle, ValueHandle stepHandle,
|
||||
ArrayRef<ValueHandle *> iter_args_handles,
|
||||
ValueRange iter_args_init_values) {
|
||||
mlir::edsc::LoopBuilder mlir::edsc::makeLoopBuilder(
|
||||
Value *iv, Value lbHandle, Value ubHandle, Value stepHandle,
|
||||
MutableArrayRef<Value> iterArgsHandles, ValueRange iterArgsInitValues) {
|
||||
mlir::edsc::LoopBuilder result;
|
||||
auto forOp = OperationHandle::createOp<loop::ForOp>(
|
||||
lbHandle, ubHandle, stepHandle, iter_args_init_values);
|
||||
*iv = ValueHandle(forOp.getInductionVar());
|
||||
auto *body = loop::getForInductionVarOwner(iv->getValue()).getBody();
|
||||
for (size_t i = 0, e = iter_args_handles.size(); i < e; ++i) {
|
||||
lbHandle, ubHandle, stepHandle, iterArgsInitValues);
|
||||
*iv = forOp.getInductionVar();
|
||||
auto *body = loop::getForInductionVarOwner(*iv).getBody();
|
||||
for (size_t i = 0, e = iterArgsHandles.size(); i < e; ++i) {
|
||||
// Skipping the induction variable.
|
||||
*(iter_args_handles[i]) = ValueHandle(body->getArgument(i + 1));
|
||||
iterArgsHandles[i] = body->getArgument(i + 1);
|
||||
}
|
||||
result.setOp(forOp);
|
||||
result.enter(body, /*prev=*/1);
|
||||
|
||||
@@ -14,11 +14,11 @@ using namespace mlir;
|
||||
using namespace mlir::edsc;
|
||||
using namespace mlir::edsc::intrinsics;
|
||||
|
||||
static SmallVector<ValueHandle, 8> getMemRefSizes(Value memRef) {
|
||||
static SmallVector<Value, 8> getMemRefSizes(Value memRef) {
|
||||
MemRefType memRefType = memRef.getType().cast<MemRefType>();
|
||||
assert(isStrided(memRefType) && "Expected strided MemRef type");
|
||||
|
||||
SmallVector<ValueHandle, 8> res;
|
||||
SmallVector<Value, 8> res;
|
||||
res.reserve(memRefType.getShape().size());
|
||||
const auto &shape = memRefType.getShape();
|
||||
for (unsigned idx = 0, n = shape.size(); idx < n; ++idx) {
|
||||
|
||||
@@ -13,45 +13,29 @@ using namespace mlir;
|
||||
using namespace mlir::edsc;
|
||||
|
||||
OperationHandle mlir::edsc::intrinsics::std_br(BlockHandle bh,
|
||||
ArrayRef<ValueHandle> operands) {
|
||||
ArrayRef<Value> operands) {
|
||||
assert(bh && "Expected already captured BlockHandle");
|
||||
for (auto &o : operands) {
|
||||
(void)o;
|
||||
assert(o && "Expected already captured ValueHandle");
|
||||
assert(o && "Expected already captured Value");
|
||||
}
|
||||
SmallVector<Value, 4> ops(operands.begin(), operands.end());
|
||||
return OperationHandle::create<BranchOp>(bh.getBlock(), ops);
|
||||
}
|
||||
|
||||
static void enforceEmptyCapturesMatchOperands(ArrayRef<ValueHandle *> captures,
|
||||
ArrayRef<ValueHandle> operands) {
|
||||
assert(captures.size() == operands.size() &&
|
||||
"Expected same number of captures as operands");
|
||||
for (auto it : llvm::zip(captures, operands)) {
|
||||
(void)it;
|
||||
assert(!std::get<0>(it)->hasValue() &&
|
||||
"Unexpected already captured ValueHandle");
|
||||
assert(std::get<1>(it) && "Expected already captured ValueHandle");
|
||||
assert(std::get<0>(it)->getType() == std::get<1>(it).getType() &&
|
||||
"Expected the same type for capture and operand");
|
||||
}
|
||||
}
|
||||
|
||||
OperationHandle mlir::edsc::intrinsics::std_br(BlockHandle *bh,
|
||||
ArrayRef<ValueHandle *> captures,
|
||||
ArrayRef<ValueHandle> operands) {
|
||||
ArrayRef<Type> types,
|
||||
MutableArrayRef<Value> captures,
|
||||
ArrayRef<Value> operands) {
|
||||
assert(!*bh && "Unexpected already captured BlockHandle");
|
||||
enforceEmptyCapturesMatchOperands(captures, operands);
|
||||
BlockBuilder(bh, captures)(/* no body */);
|
||||
BlockBuilder(bh, types, captures)(/* no body */);
|
||||
SmallVector<Value, 4> ops(operands.begin(), operands.end());
|
||||
return OperationHandle::create<BranchOp>(bh->getBlock(), ops);
|
||||
}
|
||||
|
||||
OperationHandle
|
||||
mlir::edsc::intrinsics::std_cond_br(ValueHandle cond, BlockHandle trueBranch,
|
||||
ArrayRef<ValueHandle> trueOperands,
|
||||
BlockHandle falseBranch,
|
||||
ArrayRef<ValueHandle> falseOperands) {
|
||||
OperationHandle mlir::edsc::intrinsics::std_cond_br(
|
||||
Value cond, BlockHandle trueBranch, ArrayRef<Value> trueOperands,
|
||||
BlockHandle falseBranch, ArrayRef<Value> falseOperands) {
|
||||
SmallVector<Value, 4> trueOps(trueOperands.begin(), trueOperands.end());
|
||||
SmallVector<Value, 4> falseOps(falseOperands.begin(), falseOperands.end());
|
||||
return OperationHandle::create<CondBranchOp>(
|
||||
@@ -59,16 +43,14 @@ mlir::edsc::intrinsics::std_cond_br(ValueHandle cond, BlockHandle trueBranch,
|
||||
}
|
||||
|
||||
OperationHandle mlir::edsc::intrinsics::std_cond_br(
|
||||
ValueHandle cond, BlockHandle *trueBranch,
|
||||
ArrayRef<ValueHandle *> trueCaptures, ArrayRef<ValueHandle> trueOperands,
|
||||
BlockHandle *falseBranch, ArrayRef<ValueHandle *> falseCaptures,
|
||||
ArrayRef<ValueHandle> falseOperands) {
|
||||
Value cond, BlockHandle *trueBranch, ArrayRef<Type> trueTypes,
|
||||
MutableArrayRef<Value> trueCaptures, ArrayRef<Value> trueOperands,
|
||||
BlockHandle *falseBranch, ArrayRef<Type> falseTypes,
|
||||
MutableArrayRef<Value> falseCaptures, ArrayRef<Value> falseOperands) {
|
||||
assert(!*trueBranch && "Unexpected already captured BlockHandle");
|
||||
assert(!*falseBranch && "Unexpected already captured BlockHandle");
|
||||
enforceEmptyCapturesMatchOperands(trueCaptures, trueOperands);
|
||||
enforceEmptyCapturesMatchOperands(falseCaptures, falseOperands);
|
||||
BlockBuilder(trueBranch, trueCaptures)(/* no body */);
|
||||
BlockBuilder(falseBranch, falseCaptures)(/* no body */);
|
||||
BlockBuilder(trueBranch, trueTypes, trueCaptures)(/* no body */);
|
||||
BlockBuilder(falseBranch, falseTypes, falseCaptures)(/* no body */);
|
||||
SmallVector<Value, 4> trueOps(trueOperands.begin(), trueOperands.end());
|
||||
SmallVector<Value, 4> falseOps(falseOperands.begin(), falseOperands.end());
|
||||
return OperationHandle::create<CondBranchOp>(
|
||||
|
||||
@@ -65,25 +65,8 @@ MLIRContext *mlir::edsc::ScopedContext::getContext() {
|
||||
return getBuilder().getContext();
|
||||
}
|
||||
|
||||
ValueHandle &mlir::edsc::ValueHandle::operator=(const ValueHandle &other) {
|
||||
assert(t == other.t && "Wrong type capture");
|
||||
assert(!v && "ValueHandle has already been captured, use a new name!");
|
||||
v = other.v;
|
||||
return *this;
|
||||
}
|
||||
|
||||
ValueHandle ValueHandle::create(StringRef name, ArrayRef<ValueHandle> operands,
|
||||
ArrayRef<Type> resultTypes,
|
||||
ArrayRef<NamedAttribute> attributes) {
|
||||
Operation *op =
|
||||
OperationHandle::create(name, operands, resultTypes, attributes);
|
||||
if (op->getNumResults() == 1)
|
||||
return ValueHandle(op->getResult(0));
|
||||
llvm_unreachable("unsupported operation, use an OperationHandle instead");
|
||||
}
|
||||
|
||||
OperationHandle OperationHandle::create(StringRef name,
|
||||
ArrayRef<ValueHandle> operands,
|
||||
ArrayRef<Value> operands,
|
||||
ArrayRef<Type> resultTypes,
|
||||
ArrayRef<NamedAttribute> attributes) {
|
||||
OperationState state(ScopedContext::getLocation(), name);
|
||||
@@ -156,37 +139,32 @@ mlir::edsc::BlockBuilder::BlockBuilder(BlockHandle bh, Append) {
|
||||
enter(bh.getBlock());
|
||||
}
|
||||
|
||||
mlir::edsc::BlockBuilder::BlockBuilder(BlockHandle *bh,
|
||||
ArrayRef<ValueHandle *> args) {
|
||||
mlir::edsc::BlockBuilder::BlockBuilder(BlockHandle *bh, ArrayRef<Type> types,
|
||||
MutableArrayRef<Value> args) {
|
||||
assert(!*bh && "BlockHandle already captures a block, use "
|
||||
"the explicit BockBuilder(bh, Append())({}) syntax instead.");
|
||||
SmallVector<Type, 8> types;
|
||||
for (auto *a : args) {
|
||||
assert(!a->hasValue() &&
|
||||
"Expected delayed ValueHandle that has not yet captured.");
|
||||
types.push_back(a->getType());
|
||||
}
|
||||
assert((args.empty() || args.size() == types.size()) &&
|
||||
"if args captures are specified, their number must match the number "
|
||||
"of types");
|
||||
*bh = BlockHandle::create(types);
|
||||
for (auto it : llvm::zip(args, bh->getBlock()->getArguments())) {
|
||||
*(std::get<0>(it)) = ValueHandle(std::get<1>(it));
|
||||
}
|
||||
if (!args.empty())
|
||||
for (auto it : llvm::zip(args, bh->getBlock()->getArguments()))
|
||||
std::get<0>(it) = Value(std::get<1>(it));
|
||||
enter(bh->getBlock());
|
||||
}
|
||||
|
||||
mlir::edsc::BlockBuilder::BlockBuilder(BlockHandle *bh, Region ®ion,
|
||||
ArrayRef<ValueHandle *> args) {
|
||||
ArrayRef<Type> types,
|
||||
MutableArrayRef<Value> args) {
|
||||
assert(!*bh && "BlockHandle already captures a block, use "
|
||||
"the explicit BockBuilder(bh, Append())({}) syntax instead.");
|
||||
SmallVector<Type, 8> types;
|
||||
for (auto *a : args) {
|
||||
assert(!a->hasValue() &&
|
||||
"Expected delayed ValueHandle that has not yet captured.");
|
||||
types.push_back(a->getType());
|
||||
}
|
||||
assert((args.empty() || args.size() == types.size()) &&
|
||||
"if args captures are specified, their number must match the number "
|
||||
"of types");
|
||||
*bh = BlockHandle::createInRegion(region, types);
|
||||
for (auto it : llvm::zip(args, bh->getBlock()->getArguments())) {
|
||||
*(std::get<0>(it)) = ValueHandle(std::get<1>(it));
|
||||
}
|
||||
if (!args.empty())
|
||||
for (auto it : llvm::zip(args, bh->getBlock()->getArguments()))
|
||||
std::get<0>(it) = Value(std::get<1>(it));
|
||||
enter(bh->getBlock());
|
||||
}
|
||||
|
||||
|
||||
@@ -68,12 +68,11 @@ TEST_FUNC(builder_dynamic_for_func_args) {
|
||||
|
||||
OpBuilder builder(f.getBody());
|
||||
ScopedContext scope(builder, f.getLoc());
|
||||
ValueHandle i(indexType), j(indexType), lb(f.getArgument(0)),
|
||||
ub(f.getArgument(1));
|
||||
ValueHandle f7(std_constant_float(llvm::APFloat(7.0f), f32Type));
|
||||
ValueHandle f13(std_constant_float(llvm::APFloat(13.0f), f32Type));
|
||||
ValueHandle i7(std_constant_int(7, 32));
|
||||
ValueHandle i13(std_constant_int(13, 32));
|
||||
Value i, j, lb(f.getArgument(0)), ub(f.getArgument(1));
|
||||
Value f7(std_constant_float(llvm::APFloat(7.0f), f32Type));
|
||||
Value f13(std_constant_float(llvm::APFloat(13.0f), f32Type));
|
||||
Value i7(std_constant_int(7, 32));
|
||||
Value i13(std_constant_int(13, 32));
|
||||
AffineLoopNestBuilder(&i, lb, ub, 3)([&] {
|
||||
using namespace edsc::op;
|
||||
lb *std_constant_index(3) + ub;
|
||||
@@ -119,8 +118,8 @@ TEST_FUNC(builder_dynamic_for) {
|
||||
|
||||
OpBuilder builder(f.getBody());
|
||||
ScopedContext scope(builder, f.getLoc());
|
||||
ValueHandle i(indexType), a(f.getArgument(0)), b(f.getArgument(1)),
|
||||
c(f.getArgument(2)), d(f.getArgument(3));
|
||||
Value i, a(f.getArgument(0)), b(f.getArgument(1)), c(f.getArgument(2)),
|
||||
d(f.getArgument(3));
|
||||
using namespace edsc::op;
|
||||
AffineLoopNestBuilder(&i, a - b, c + d, 2)();
|
||||
|
||||
@@ -141,8 +140,8 @@ TEST_FUNC(builder_loop_for) {
|
||||
|
||||
OpBuilder builder(f.getBody());
|
||||
ScopedContext scope(builder, f.getLoc());
|
||||
ValueHandle i(indexType), a(f.getArgument(0)), b(f.getArgument(1)),
|
||||
c(f.getArgument(2)), d(f.getArgument(3));
|
||||
Value i, a(f.getArgument(0)), b(f.getArgument(1)), c(f.getArgument(2)),
|
||||
d(f.getArgument(3));
|
||||
using namespace edsc::op;
|
||||
LoopNestBuilder(&i, a - b, c + d, a)();
|
||||
|
||||
@@ -163,8 +162,8 @@ TEST_FUNC(builder_max_min_for) {
|
||||
|
||||
OpBuilder builder(f.getBody());
|
||||
ScopedContext scope(builder, f.getLoc());
|
||||
ValueHandle i(indexType), lb1(f.getArgument(0)), lb2(f.getArgument(1)),
|
||||
ub1(f.getArgument(2)), ub2(f.getArgument(3));
|
||||
Value i, lb1(f.getArgument(0)), lb2(f.getArgument(1)), ub1(f.getArgument(2)),
|
||||
ub2(f.getArgument(3));
|
||||
AffineLoopNestBuilder(&i, {lb1, lb2}, {ub1, ub2}, 1)();
|
||||
std_ret();
|
||||
|
||||
@@ -183,17 +182,20 @@ TEST_FUNC(builder_blocks) {
|
||||
|
||||
OpBuilder builder(f.getBody());
|
||||
ScopedContext scope(builder, f.getLoc());
|
||||
ValueHandle c1(ValueHandle::create<ConstantIntOp>(42, 32)),
|
||||
c2(ValueHandle::create<ConstantIntOp>(1234, 32));
|
||||
ValueHandle arg1(c1.getType()), arg2(c1.getType()), arg3(c1.getType()),
|
||||
arg4(c1.getType()), r(c1.getType());
|
||||
|
||||
Value c1(std_constant_int(42, 32)), c2(std_constant_int(1234, 32));
|
||||
Value r;
|
||||
Value args12[2];
|
||||
Value &arg1 = args12[0], &arg2 = args12[1];
|
||||
Value args34[2];
|
||||
Value &arg3 = args34[0], &arg4 = args34[1];
|
||||
BlockHandle b1, b2, functionBlock(&f.front());
|
||||
BlockBuilder(&b1, {&arg1, &arg2})(
|
||||
BlockBuilder(&b1, {c1.getType(), c1.getType()}, args12)(
|
||||
// b2 has not yet been constructed, need to come back later.
|
||||
// This is a byproduct of non-structured control-flow.
|
||||
);
|
||||
BlockBuilder(&b2, {&arg3, &arg4})([&] { std_br(b1, {arg3, arg4}); });
|
||||
BlockBuilder(&b2, {c1.getType(), c1.getType()}, args34)([&] {
|
||||
std_br(b1, {arg3, arg4});
|
||||
});
|
||||
// The insertion point within the toplevel function is now past b2, we will
|
||||
// need to get back the entry block.
|
||||
// This is what happens with unstructured control-flow..
|
||||
@@ -226,24 +228,25 @@ TEST_FUNC(builder_blocks_eager) {
|
||||
|
||||
OpBuilder builder(f.getBody());
|
||||
ScopedContext scope(builder, f.getLoc());
|
||||
ValueHandle c1(ValueHandle::create<ConstantIntOp>(42, 32)),
|
||||
c2(ValueHandle::create<ConstantIntOp>(1234, 32));
|
||||
ValueHandle arg1(c1.getType()), arg2(c1.getType()), arg3(c1.getType()),
|
||||
arg4(c1.getType()), r(c1.getType());
|
||||
Value c1(std_constant_int(42, 32)), c2(std_constant_int(1234, 32));
|
||||
Value res;
|
||||
Value args1And2[2], args3And4[2];
|
||||
Value &arg1 = args1And2[0], &arg2 = args1And2[1], &arg3 = args3And4[0],
|
||||
&arg4 = args3And4[1];
|
||||
|
||||
// clang-format off
|
||||
BlockHandle b1, b2;
|
||||
{ // Toplevel function scope.
|
||||
// Build a new block for b1 eagerly.
|
||||
std_br(&b1, {&arg1, &arg2}, {c1, c2});
|
||||
std_br(&b1, {c1.getType(), c1.getType()}, args1And2, {c1, c2});
|
||||
// Construct a new block b2 explicitly with a branch into b1.
|
||||
BlockBuilder(&b2, {&arg3, &arg4})([&]{
|
||||
BlockBuilder(&b2, {c1.getType(), c1.getType()}, args3And4)([&]{
|
||||
std_br(b1, {arg3, arg4});
|
||||
});
|
||||
/// And come back to append into b1 once b2 exists.
|
||||
BlockBuilder(b1, Append())([&]{
|
||||
r = arg1 + arg2;
|
||||
std_br(b2, {arg1, r});
|
||||
res = arg1 + arg2;
|
||||
std_br(b2, {arg1, res});
|
||||
});
|
||||
}
|
||||
|
||||
@@ -268,15 +271,14 @@ TEST_FUNC(builder_cond_branch) {
|
||||
|
||||
OpBuilder builder(f.getBody());
|
||||
ScopedContext scope(builder, f.getLoc());
|
||||
ValueHandle funcArg(f.getArgument(0));
|
||||
ValueHandle c32(ValueHandle::create<ConstantIntOp>(32, 32)),
|
||||
c64(ValueHandle::create<ConstantIntOp>(64, 64)),
|
||||
c42(ValueHandle::create<ConstantIntOp>(42, 32));
|
||||
ValueHandle arg1(c32.getType()), arg2(c64.getType()), arg3(c32.getType());
|
||||
|
||||
Value funcArg(f.getArgument(0));
|
||||
Value c32(std_constant_int(32, 32)), c64(std_constant_int(64, 64)),
|
||||
c42(std_constant_int(42, 32));
|
||||
Value arg1;
|
||||
Value args23[2];
|
||||
BlockHandle b1, b2, functionBlock(&f.front());
|
||||
BlockBuilder(&b1, {&arg1})([&] { std_ret(); });
|
||||
BlockBuilder(&b2, {&arg2, &arg3})([&] { std_ret(); });
|
||||
BlockBuilder(&b1, c32.getType(), arg1)([&] { std_ret(); });
|
||||
BlockBuilder(&b2, {c64.getType(), c32.getType()}, args23)([&] { std_ret(); });
|
||||
// Get back to entry block and add a conditional branch
|
||||
BlockBuilder(functionBlock, Append())([&] {
|
||||
std_cond_br(funcArg, b1, {c32}, b2, {c64, c42});
|
||||
@@ -304,15 +306,16 @@ TEST_FUNC(builder_cond_branch_eager) {
|
||||
|
||||
OpBuilder builder(f.getBody());
|
||||
ScopedContext scope(builder, f.getLoc());
|
||||
ValueHandle funcArg(f.getArgument(0));
|
||||
ValueHandle c32(ValueHandle::create<ConstantIntOp>(32, 32)),
|
||||
c64(ValueHandle::create<ConstantIntOp>(64, 64)),
|
||||
c42(ValueHandle::create<ConstantIntOp>(42, 32));
|
||||
ValueHandle arg1(c32.getType()), arg2(c64.getType()), arg3(c32.getType());
|
||||
Value arg0(f.getArgument(0));
|
||||
Value c32(std_constant_int(32, 32)), c64(std_constant_int(64, 64)),
|
||||
c42(std_constant_int(42, 32));
|
||||
|
||||
// clang-format off
|
||||
BlockHandle b1, b2;
|
||||
std_cond_br(funcArg, &b1, {&arg1}, {c32}, &b2, {&arg2, &arg3}, {c64, c42});
|
||||
Value arg1[1], args2And3[2];
|
||||
std_cond_br(arg0,
|
||||
&b1, c32.getType(), arg1, c32,
|
||||
&b2, {c64.getType(), c32.getType()}, args2And3, {c64, c42});
|
||||
BlockBuilder(b1, Append())([]{
|
||||
std_ret();
|
||||
});
|
||||
@@ -336,7 +339,6 @@ TEST_FUNC(builder_cond_branch_eager) {
|
||||
|
||||
TEST_FUNC(builder_helpers) {
|
||||
using namespace edsc::op;
|
||||
auto indexType = IndexType::get(&globalContext());
|
||||
auto f32Type = FloatType::getF32(&globalContext());
|
||||
auto memrefType =
|
||||
MemRefType::get({ShapedType::kDynamicSize, ShapedType::kDynamicSize,
|
||||
@@ -348,21 +350,20 @@ TEST_FUNC(builder_helpers) {
|
||||
OpBuilder builder(f.getBody());
|
||||
ScopedContext scope(builder, f.getLoc());
|
||||
// clang-format off
|
||||
ValueHandle f7(
|
||||
ValueHandle::create<ConstantFloatOp>(llvm::APFloat(7.0f), f32Type));
|
||||
Value f7 = std_constant_float(llvm::APFloat(7.0f), f32Type);
|
||||
MemRefBoundsCapture vA(f.getArgument(0)), vB(f.getArgument(1)),
|
||||
vC(f.getArgument(2));
|
||||
AffineIndexedValue A(f.getArgument(0)), B(f.getArgument(1)), C(f.getArgument(2));
|
||||
ValueHandle i(indexType), j(indexType), k1(indexType), k2(indexType),
|
||||
lb0(indexType), lb1(indexType), lb2(indexType),
|
||||
ub0(indexType), ub1(indexType), ub2(indexType);
|
||||
Value ivs[2];
|
||||
Value &i = ivs[0], &j = ivs[1];
|
||||
Value k1, k2, lb0, lb1, lb2, ub0, ub1, ub2;
|
||||
int64_t step0, step1, step2;
|
||||
std::tie(lb0, ub0, step0) = vA.range(0);
|
||||
std::tie(lb1, ub1, step1) = vA.range(1);
|
||||
lb2 = vA.lb(2);
|
||||
ub2 = vA.ub(2);
|
||||
step2 = vA.step(2);
|
||||
AffineLoopNestBuilder({&i, &j}, {lb0, lb1}, {ub0, ub1}, {step0, step1})([&]{
|
||||
AffineLoopNestBuilder(ivs, {lb0, lb1}, {ub0, ub1}, {step0, step1})([&]{
|
||||
AffineLoopNestBuilder(&k1, lb2, ub2, step2)([&]{
|
||||
C(i, j, k1) = f7 + A(i, j, k1) + B(i, j, k1);
|
||||
});
|
||||
@@ -393,45 +394,6 @@ TEST_FUNC(builder_helpers) {
|
||||
f.erase();
|
||||
}
|
||||
|
||||
TEST_FUNC(custom_ops) {
|
||||
using namespace edsc::op;
|
||||
auto indexType = IndexType::get(&globalContext());
|
||||
auto f = makeFunction("custom_ops", {}, {indexType, indexType});
|
||||
|
||||
OpBuilder builder(f.getBody());
|
||||
ScopedContext scope(builder, f.getLoc());
|
||||
CustomOperation<ValueHandle> MY_CUSTOM_OP("my_custom_op");
|
||||
CustomOperation<OperationHandle> MY_CUSTOM_OP_0("my_custom_op_0");
|
||||
CustomOperation<OperationHandle> MY_CUSTOM_OP_2("my_custom_op_2");
|
||||
|
||||
// clang-format off
|
||||
ValueHandle vh(indexType), vh20(indexType), vh21(indexType);
|
||||
OperationHandle ih0, ih2;
|
||||
ValueHandle m(indexType), n(indexType);
|
||||
ValueHandle M(f.getArgument(0)), N(f.getArgument(1));
|
||||
ValueHandle ten(std_constant_index(10)), twenty(std_constant_index(20));
|
||||
AffineLoopNestBuilder({&m, &n}, {M, N}, {M + ten, N + twenty}, {1, 1})([&]{
|
||||
vh = MY_CUSTOM_OP({m, m + n}, {indexType}, {});
|
||||
ih0 = MY_CUSTOM_OP_0({m, m + n}, {});
|
||||
ih2 = MY_CUSTOM_OP_2({m, m + n}, {indexType, indexType});
|
||||
// These captures are verbose for now, can improve when used in practice.
|
||||
vh20 = ValueHandle(ih2.getOperation()->getResult(0));
|
||||
vh21 = ValueHandle(ih2.getOperation()->getResult(1));
|
||||
MY_CUSTOM_OP({vh20, vh21}, {indexType}, {});
|
||||
});
|
||||
|
||||
// CHECK-LABEL: @custom_ops
|
||||
// CHECK: affine.for %{{.*}} {{.*}}
|
||||
// CHECK: affine.for %{{.*}} {{.*}}
|
||||
// CHECK: {{.*}} = "my_custom_op"{{.*}} : (index, index) -> index
|
||||
// CHECK: "my_custom_op_0"{{.*}} : (index, index) -> ()
|
||||
// CHECK: [[TWO:%[a-z0-9]+]]:2 = "my_custom_op_2"{{.*}} : (index, index) -> (index, index)
|
||||
// CHECK: {{.*}} = "my_custom_op"([[TWO]]#0, [[TWO]]#1) : (index, index) -> index
|
||||
// clang-format on
|
||||
f.print(llvm::outs());
|
||||
f.erase();
|
||||
}
|
||||
|
||||
TEST_FUNC(insertion_in_block) {
|
||||
using namespace edsc::op;
|
||||
auto indexType = IndexType::get(&globalContext());
|
||||
@@ -441,11 +403,11 @@ TEST_FUNC(insertion_in_block) {
|
||||
ScopedContext scope(builder, f.getLoc());
|
||||
BlockHandle b1;
|
||||
// clang-format off
|
||||
ValueHandle::create<ConstantIntOp>(0, 32);
|
||||
BlockBuilder(&b1, {})([]{
|
||||
ValueHandle::create<ConstantIntOp>(1, 32);
|
||||
std_constant_int(0, 32);
|
||||
(BlockBuilder(&b1))([]{
|
||||
std_constant_int(1, 32);
|
||||
});
|
||||
ValueHandle::create<ConstantIntOp>(2, 32);
|
||||
std_constant_int(2, 32);
|
||||
// CHECK-LABEL: @insertion_in_block
|
||||
// CHECK: {{.*}} = constant 0 : i32
|
||||
// CHECK: {{.*}} = constant 2 : i32
|
||||
@@ -469,8 +431,8 @@ TEST_FUNC(zero_and_std_sign_extendi_op_i1_to_i8) {
|
||||
AffineIndexedValue A(f.getArgument(0));
|
||||
AffineIndexedValue B(f.getArgument(1));
|
||||
// clang-format off
|
||||
edsc::intrinsics::std_zero_extendi(*A, i8Type);
|
||||
edsc::intrinsics::std_sign_extendi(*B, i8Type);
|
||||
edsc::intrinsics::std_zero_extendi(A, i8Type);
|
||||
edsc::intrinsics::std_sign_extendi(B, i8Type);
|
||||
// CHECK-LABEL: @zero_and_std_sign_extendi_op
|
||||
// CHECK: %[[SRC1:.*]] = affine.load
|
||||
// CHECK: zexti %[[SRC1]] : i1 to i8
|
||||
@@ -489,8 +451,8 @@ TEST_FUNC(operator_or) {
|
||||
ScopedContext scope(builder, f.getLoc());
|
||||
|
||||
using op::operator||;
|
||||
ValueHandle lhs(f.getArgument(0));
|
||||
ValueHandle rhs(f.getArgument(1));
|
||||
Value lhs(f.getArgument(0));
|
||||
Value rhs(f.getArgument(1));
|
||||
lhs || rhs;
|
||||
|
||||
// CHECK-LABEL: @operator_or
|
||||
@@ -508,8 +470,8 @@ TEST_FUNC(operator_and) {
|
||||
ScopedContext scope(builder, f.getLoc());
|
||||
|
||||
using op::operator&&;
|
||||
ValueHandle lhs(f.getArgument(0));
|
||||
ValueHandle rhs(f.getArgument(1));
|
||||
Value lhs(f.getArgument(0));
|
||||
Value rhs(f.getArgument(1));
|
||||
lhs &&rhs;
|
||||
|
||||
// CHECK-LABEL: @operator_and
|
||||
@@ -521,7 +483,6 @@ TEST_FUNC(operator_and) {
|
||||
|
||||
TEST_FUNC(select_op_i32) {
|
||||
using namespace edsc::op;
|
||||
auto indexType = IndexType::get(&globalContext());
|
||||
auto f32Type = FloatType::getF32(&globalContext());
|
||||
auto memrefType = MemRefType::get(
|
||||
{ShapedType::kDynamicSize, ShapedType::kDynamicSize}, f32Type, {}, 0);
|
||||
@@ -530,17 +491,13 @@ TEST_FUNC(select_op_i32) {
|
||||
OpBuilder builder(f.getBody());
|
||||
ScopedContext scope(builder, f.getLoc());
|
||||
// clang-format off
|
||||
ValueHandle zero = std_constant_index(0), one = std_constant_index(1);
|
||||
Value zero = std_constant_index(0), one = std_constant_index(1);
|
||||
MemRefBoundsCapture vA(f.getArgument(0));
|
||||
AffineIndexedValue A(f.getArgument(0));
|
||||
ValueHandle i(indexType), j(indexType);
|
||||
AffineLoopNestBuilder({&i, &j}, {zero, zero}, {one, one}, {1, 1})([&]{
|
||||
// This test exercises AffineIndexedValue::operator Value.
|
||||
// Without it, one must force conversion to ValueHandle as such:
|
||||
// std_select(
|
||||
// i == zero, ValueHandle(A(zero, zero)), ValueHandle(ValueA(i, j)))
|
||||
using edsc::op::operator==;
|
||||
std_select(i == zero, *A(zero, zero), *A(i, j));
|
||||
Value ivs[2];
|
||||
Value &i = ivs[0], &j = ivs[1];
|
||||
AffineLoopNestBuilder(ivs, {zero, zero}, {one, one}, {1, 1})([&]{
|
||||
std_select(eq(i, zero), A(zero, zero), A(i, j));
|
||||
});
|
||||
|
||||
// CHECK-LABEL: @select_op
|
||||
@@ -556,7 +513,6 @@ TEST_FUNC(select_op_i32) {
|
||||
}
|
||||
|
||||
TEST_FUNC(select_op_f32) {
|
||||
auto indexType = IndexType::get(&globalContext());
|
||||
auto f32Type = FloatType::getF32(&globalContext());
|
||||
auto memrefType = MemRefType::get(
|
||||
{ShapedType::kDynamicSize, ShapedType::kDynamicSize}, f32Type, {}, 0);
|
||||
@@ -565,18 +521,19 @@ TEST_FUNC(select_op_f32) {
|
||||
OpBuilder builder(f.getBody());
|
||||
ScopedContext scope(builder, f.getLoc());
|
||||
// clang-format off
|
||||
ValueHandle zero = std_constant_index(0), one = std_constant_index(1);
|
||||
Value zero = std_constant_index(0), one = std_constant_index(1);
|
||||
MemRefBoundsCapture vA(f.getArgument(0)), vB(f.getArgument(1));
|
||||
AffineIndexedValue A(f.getArgument(0)), B(f.getArgument(1));
|
||||
ValueHandle i(indexType), j(indexType);
|
||||
AffineLoopNestBuilder({&i, &j}, {zero, zero}, {one, one}, {1, 1})([&]{
|
||||
Value ivs[2];
|
||||
Value &i = ivs[0], &j = ivs[1];
|
||||
AffineLoopNestBuilder(ivs, {zero, zero}, {one, one}, {1, 1})([&]{
|
||||
using namespace edsc::op;
|
||||
std_select(B(i, j) == B(i + one, j), *A(zero, zero), *A(i, j));
|
||||
std_select(B(i, j) != B(i + one, j), *A(zero, zero), *A(i, j));
|
||||
std_select(B(i, j) >= B(i + one, j), *A(zero, zero), *A(i, j));
|
||||
std_select(B(i, j) <= B(i + one, j), *A(zero, zero), *A(i, j));
|
||||
std_select(B(i, j) < B(i + one, j), *A(zero, zero), *A(i, j));
|
||||
std_select(B(i, j) > B(i + one, j), *A(zero, zero), *A(i, j));
|
||||
std_select(eq(B(i, j), B(i + one, j)), A(zero, zero), A(i, j));
|
||||
std_select(ne(B(i, j), B(i + one, j)), A(zero, zero), A(i, j));
|
||||
std_select(B(i, j) >= B(i + one, j), A(zero, zero), A(i, j));
|
||||
std_select(B(i, j) <= B(i + one, j), A(zero, zero), A(i, j));
|
||||
std_select(B(i, j) < B(i + one, j), A(zero, zero), A(i, j));
|
||||
std_select(B(i, j) > B(i + one, j), A(zero, zero), A(i, j));
|
||||
});
|
||||
|
||||
// CHECK-LABEL: @select_op
|
||||
@@ -632,7 +589,6 @@ TEST_FUNC(select_op_f32) {
|
||||
// Inject an EDSC-constructed computation to exercise imperfectly nested 2-d
|
||||
// tiling.
|
||||
TEST_FUNC(tile_2d) {
|
||||
auto indexType = IndexType::get(&globalContext());
|
||||
auto memrefType =
|
||||
MemRefType::get({ShapedType::kDynamicSize, ShapedType::kDynamicSize,
|
||||
ShapedType::kDynamicSize},
|
||||
@@ -641,17 +597,19 @@ TEST_FUNC(tile_2d) {
|
||||
|
||||
OpBuilder builder(f.getBody());
|
||||
ScopedContext scope(builder, f.getLoc());
|
||||
ValueHandle zero = std_constant_index(0);
|
||||
Value zero = std_constant_index(0);
|
||||
MemRefBoundsCapture vA(f.getArgument(0)), vB(f.getArgument(1)),
|
||||
vC(f.getArgument(2));
|
||||
AffineIndexedValue A(f.getArgument(0)), B(f.getArgument(1)),
|
||||
C(f.getArgument(2));
|
||||
ValueHandle i(indexType), j(indexType), k1(indexType), k2(indexType);
|
||||
ValueHandle M(vC.ub(0)), N(vC.ub(1)), O(vC.ub(2));
|
||||
Value ivs[2];
|
||||
Value &i = ivs[0], &j = ivs[1];
|
||||
Value k1, k2;
|
||||
Value M(vC.ub(0)), N(vC.ub(1)), O(vC.ub(2));
|
||||
|
||||
// clang-format off
|
||||
using namespace edsc::op;
|
||||
AffineLoopNestBuilder({&i, &j}, {zero, zero}, {M, N}, {1, 1})([&]{
|
||||
AffineLoopNestBuilder(ivs, {zero, zero}, {M, N}, {1, 1})([&]{
|
||||
AffineLoopNestBuilder(&k1, zero, O, 1)([&]{
|
||||
C(i, j, k1) = A(i, j, k1) + B(i, j, k1);
|
||||
});
|
||||
@@ -661,10 +619,8 @@ TEST_FUNC(tile_2d) {
|
||||
});
|
||||
// clang-format on
|
||||
|
||||
auto li = getForInductionVarOwner(i.getValue()),
|
||||
lj = getForInductionVarOwner(j.getValue()),
|
||||
lk1 = getForInductionVarOwner(k1.getValue()),
|
||||
lk2 = getForInductionVarOwner(k2.getValue());
|
||||
auto li = getForInductionVarOwner(i), lj = getForInductionVarOwner(j),
|
||||
lk1 = getForInductionVarOwner(k1), lk2 = getForInductionVarOwner(k2);
|
||||
auto indicesL1 = mlir::tile({li, lj}, {512, 1024}, {lk1, lk2});
|
||||
auto lii1 = indicesL1[0][0], ljj1 = indicesL1[1][0];
|
||||
mlir::tile({ljj1, lii1}, {32, 16}, ljj1);
|
||||
@@ -713,15 +669,15 @@ TEST_FUNC(indirect_access) {
|
||||
|
||||
OpBuilder builder(f.getBody());
|
||||
ScopedContext scope(builder, f.getLoc());
|
||||
ValueHandle zero = std_constant_index(0);
|
||||
Value zero = std_constant_index(0);
|
||||
MemRefBoundsCapture vC(f.getArgument(2));
|
||||
AffineIndexedValue B(f.getArgument(1)), D(f.getArgument(3));
|
||||
StdIndexedValue A(f.getArgument(0)), C(f.getArgument(2));
|
||||
ValueHandle i(builder.getIndexType()), N(vC.ub(0));
|
||||
Value i, N(vC.ub(0));
|
||||
|
||||
// clang-format off
|
||||
AffineLoopNestBuilder(&i, zero, N, 1)([&]{
|
||||
C((ValueHandle)D(i)) = A((ValueHandle)B(i));
|
||||
C((Value)D(i)) = A((Value)B(i));
|
||||
});
|
||||
// clang-format on
|
||||
|
||||
@@ -747,12 +703,12 @@ TEST_FUNC(empty_map_load_store) {
|
||||
|
||||
OpBuilder builder(f.getBody());
|
||||
ScopedContext scope(builder, f.getLoc());
|
||||
ValueHandle zero = std_constant_index(0);
|
||||
ValueHandle one = std_constant_index(1);
|
||||
Value zero = std_constant_index(0);
|
||||
Value one = std_constant_index(1);
|
||||
AffineIndexedValue input(f.getArgument(0)), res(f.getArgument(1));
|
||||
ValueHandle iv(builder.getIndexType());
|
||||
|
||||
// clang-format off
|
||||
Value iv;
|
||||
AffineLoopNestBuilder(&iv, zero, one, 1)([&]{
|
||||
res() = input();
|
||||
});
|
||||
@@ -784,7 +740,7 @@ TEST_FUNC(affine_if_op) {
|
||||
OpBuilder builder(f.getBody());
|
||||
ScopedContext scope(builder, f.getLoc());
|
||||
|
||||
ValueHandle zero = std_constant_index(0), ten = std_constant_index(10);
|
||||
Value zero = std_constant_index(0), ten = std_constant_index(10);
|
||||
|
||||
SmallVector<bool, 4> isEq = {false, false, false, false};
|
||||
SmallVector<AffineExpr, 4> affineExprs = {
|
||||
@@ -834,7 +790,7 @@ TEST_FUNC(linalg_generic_pointwise_test) {
|
||||
|
||||
OpBuilder builder(f.getBody());
|
||||
ScopedContext scope(builder, f.getLoc());
|
||||
ValueHandle A(f.getArgument(0)), B(f.getArgument(1)), C(f.getArgument(2));
|
||||
Value A(f.getArgument(0)), B(f.getArgument(1)), C(f.getArgument(2));
|
||||
AffineExpr i, j;
|
||||
bindDims(&globalContext(), i, j);
|
||||
StructuredIndexed SA(A), SB(B), SC(C);
|
||||
@@ -864,12 +820,12 @@ TEST_FUNC(linalg_generic_matmul_test) {
|
||||
auto f32Type = FloatType::getF32(&globalContext());
|
||||
auto memrefType = MemRefType::get(
|
||||
{ShapedType::kDynamicSize, ShapedType::kDynamicSize}, f32Type, {}, 0);
|
||||
auto f =
|
||||
makeFunction("linalg_generic_matmul", {}, {memrefType, memrefType, memrefType});
|
||||
auto f = makeFunction("linalg_generic_matmul", {},
|
||||
{memrefType, memrefType, memrefType});
|
||||
|
||||
OpBuilder builder(f.getBody());
|
||||
ScopedContext scope(builder, f.getLoc());
|
||||
linalg_generic_matmul(makeValueHandles(llvm::to_vector<3>(f.getArguments())));
|
||||
linalg_generic_matmul(f.getArguments());
|
||||
|
||||
f.print(llvm::outs());
|
||||
f.erase();
|
||||
@@ -902,8 +858,8 @@ TEST_FUNC(linalg_generic_conv_nhwc) {
|
||||
|
||||
OpBuilder builder(f.getBody());
|
||||
ScopedContext scope(builder, f.getLoc());
|
||||
linalg_generic_conv_nhwc(makeValueHandles(llvm::to_vector<3>(f.getArguments())),
|
||||
/*strides=*/{3, 4}, /*dilations=*/{5, 6});
|
||||
linalg_generic_conv_nhwc(f.getArguments(),
|
||||
/*strides=*/{3, 4}, /*dilations=*/{5, 6});
|
||||
|
||||
f.print(llvm::outs());
|
||||
f.erase();
|
||||
@@ -936,9 +892,9 @@ TEST_FUNC(linalg_generic_dilated_conv_nhwc) {
|
||||
|
||||
OpBuilder builder(f.getBody());
|
||||
ScopedContext scope(builder, f.getLoc());
|
||||
linalg_generic_dilated_conv_nhwc(makeValueHandles(f.getArguments()),
|
||||
/*depth_multiplier=*/7,
|
||||
/*strides=*/{3, 4}, /*dilations=*/{5, 6});
|
||||
linalg_generic_dilated_conv_nhwc(f.getArguments(),
|
||||
/*depth_multiplier=*/7,
|
||||
/*strides=*/{3, 4}, /*dilations=*/{5, 6});
|
||||
|
||||
f.print(llvm::outs());
|
||||
f.erase();
|
||||
@@ -958,7 +914,7 @@ TEST_FUNC(linalg_metadata_ops) {
|
||||
ScopedContext scope(builder, f.getLoc());
|
||||
AffineExpr i, j, k;
|
||||
bindDims(&globalContext(), i, j, k);
|
||||
ValueHandle v(f.getArgument(0));
|
||||
Value v(f.getArgument(0));
|
||||
auto reshaped = linalg_reshape(v, ArrayRef<ArrayRef<AffineExpr>>{{i, j}, k});
|
||||
linalg_reshape(memrefType, reshaped,
|
||||
ArrayRef<ArrayRef<AffineExpr>>{{i, j}, k});
|
||||
@@ -1015,7 +971,7 @@ TEST_FUNC(linalg_tensors_test) {
|
||||
|
||||
OpBuilder builder(f.getBody());
|
||||
ScopedContext scope(builder, f.getLoc());
|
||||
ValueHandle A(f.getArgument(0)), B(f.getArgument(1));
|
||||
Value A(f.getArgument(0)), B(f.getArgument(1));
|
||||
AffineExpr i, j;
|
||||
bindDims(&globalContext(), i, j);
|
||||
StructuredIndexed SA(A), SB(B), SC(tensorType);
|
||||
@@ -1023,7 +979,7 @@ TEST_FUNC(linalg_tensors_test) {
|
||||
linalg_generic_pointwise_max(SA({i, j}), SB({i, j}), SC({i, j}));
|
||||
linalg_generic_pointwise_tanh(SA({i, j}), SC({i, j}));
|
||||
Value o1 = linalg_generic_matmul(A, B, tensorType)->getResult(0);
|
||||
linalg_generic_matmul(A, B, ValueHandle(o1), tensorType);
|
||||
linalg_generic_matmul(A, B, o1, tensorType);
|
||||
|
||||
f.print(llvm::outs());
|
||||
f.erase();
|
||||
@@ -1064,7 +1020,7 @@ TEST_FUNC(memref_vector_matmul_test) {
|
||||
|
||||
OpBuilder builder(f.getBody());
|
||||
ScopedContext scope(builder, f.getLoc());
|
||||
ValueHandle A(f.getArgument(0)), B(f.getArgument(1)), C(f.getArgument(2));
|
||||
Value A(f.getArgument(0)), B(f.getArgument(1)), C(f.getArgument(2));
|
||||
auto contractionBuilder = [](ArrayRef<BlockArgument> args) {
|
||||
assert(args.size() == 3 && "expected 3 block arguments");
|
||||
(linalg_yield(vector_contraction_matmul(args[0], args[1], args[2])));
|
||||
@@ -1083,19 +1039,19 @@ TEST_FUNC(builder_loop_for_yield) {
|
||||
|
||||
OpBuilder builder(f.getBody());
|
||||
ScopedContext scope(builder, f.getLoc());
|
||||
ValueHandle init0 = std_constant_float(llvm::APFloat(1.0f), f32Type);
|
||||
ValueHandle init1 = std_constant_float(llvm::APFloat(2.0f), f32Type);
|
||||
ValueHandle i(indexType), a(f.getArgument(0)), b(f.getArgument(1)),
|
||||
c(f.getArgument(2)), d(f.getArgument(3));
|
||||
ValueHandle arg0(f32Type);
|
||||
ValueHandle arg1(f32Type);
|
||||
Value init0 = std_constant_float(llvm::APFloat(1.0f), f32Type);
|
||||
Value init1 = std_constant_float(llvm::APFloat(2.0f), f32Type);
|
||||
Value i, a(f.getArgument(0)), b(f.getArgument(1)), c(f.getArgument(2)),
|
||||
d(f.getArgument(3));
|
||||
Value args01[2];
|
||||
Value &arg0 = args01[0], &arg1 = args01[1];
|
||||
using namespace edsc::op;
|
||||
auto results =
|
||||
LoopNestBuilder(&i, a - b, c + d, a, {&arg0, &arg1}, {init0, init1})([&] {
|
||||
LoopNestBuilder(&i, a - b, c + d, a, args01, {init0, init1})([&] {
|
||||
auto sum = arg0 + arg1;
|
||||
loop_yield(ArrayRef<ValueHandle>{arg1, sum});
|
||||
loop_yield(ArrayRef<Value>{arg1, sum});
|
||||
});
|
||||
ValueHandle(results[0]) + ValueHandle(results[1]);
|
||||
results[0] + results[1];
|
||||
|
||||
// clang-format off
|
||||
// CHECK-LABEL: func @builder_loop_for_yield(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index) {
|
||||
|
||||
@@ -16,9 +16,9 @@
|
||||
// IMPL-NEXT: AffineMap::get(2, 0, {d0}, context) };
|
||||
//
|
||||
// IMPL: Test1Op::regionBuilder(Block &block) {
|
||||
// IMPL: ValueHandle [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]);
|
||||
// IMPL: ValueHandle [[d:.*]] = std_mulf([[a]], [[b]]);
|
||||
// IMPL: ValueHandle [[e:.*]] = std_addf([[c]], [[d]]);
|
||||
// IMPL: Value [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]);
|
||||
// IMPL: Value [[d:.*]] = std_mulf([[a]], [[b]]);
|
||||
// IMPL: Value [[e:.*]] = std_addf([[c]], [[d]]);
|
||||
// IMPL: (linalg_yield(ValueRange{ [[e]] }));
|
||||
//
|
||||
ods_def<Test1Op> :
|
||||
@@ -41,9 +41,9 @@ def test1(A: f32(M, K), B: f32(K)) -> (C: f32(M)) {
|
||||
// IMPL-NEXT: AffineMap::get(3, 0, {d0, d1}, context) };
|
||||
//
|
||||
// IMPL: Test2Op::regionBuilder(Block &block) {
|
||||
// IMPL: ValueHandle [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]);
|
||||
// IMPL: ValueHandle [[d:.*]] = std_mulf([[a]], [[b]]);
|
||||
// IMPL: ValueHandle [[e:.*]] = std_addf([[c]], [[d]]);
|
||||
// IMPL: Value [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]);
|
||||
// IMPL: Value [[d:.*]] = std_mulf([[a]], [[b]]);
|
||||
// IMPL: Value [[e:.*]] = std_addf([[c]], [[d]]);
|
||||
// IMPL: (linalg_yield(ValueRange{ [[e]] }));
|
||||
//
|
||||
ods_def<Test2Op> :
|
||||
@@ -66,9 +66,9 @@ def test2(A: f32(M, K), B: f32(K, N)) -> (C: f32(M, N)) {
|
||||
// IMPL-NEXT: AffineMap::get(4, 0, {d0, d1, d2}, context) };
|
||||
//
|
||||
// IMPL: Test3Op::regionBuilder(Block &block) {
|
||||
// IMPL: ValueHandle [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]);
|
||||
// IMPL: ValueHandle [[d:.*]] = std_mulf([[a]], [[b]]);
|
||||
// IMPL: ValueHandle [[e:.*]] = std_addf([[c]], [[d]]);
|
||||
// IMPL: Value [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]);
|
||||
// IMPL: Value [[d:.*]] = std_mulf([[a]], [[b]]);
|
||||
// IMPL: Value [[e:.*]] = std_addf([[c]], [[d]]);
|
||||
// IMPL: (linalg_yield(ValueRange{ [[e]] }));
|
||||
//
|
||||
ods_def<Test3Op> :
|
||||
|
||||
@@ -1601,7 +1601,7 @@ void TCParser::printRegionBuilder(llvm::raw_ostream &os, StringRef cppOpName,
|
||||
printExpr(subExprsStringStream, *e);
|
||||
});
|
||||
subExprsStringStream.flush();
|
||||
const char *tensorExprFmt = "\n ValueHandle _{0} = {1}({2});";
|
||||
const char *tensorExprFmt = "\n Value _{0} = {1}({2});";
|
||||
os << llvm::formatv(tensorExprFmt, ++count, pTensorExpr->operationName,
|
||||
subExprs);
|
||||
subExprsMap[pTensorExpr] = count;
|
||||
@@ -1613,7 +1613,7 @@ void TCParser::printRegionBuilder(llvm::raw_ostream &os, StringRef cppOpName,
|
||||
using namespace edsc;
|
||||
using namespace intrinsics;
|
||||
auto args = block.getArguments();
|
||||
ValueHandle {1};
|
||||
Value {1};
|
||||
{2}
|
||||
(linalg_yield(ValueRange{ {3} }));
|
||||
})FMT";
|
||||
|
||||
Reference in New Issue
Block a user