[mlir][EDSC] Retire ValueHandle

The EDSC discussion [thread](https://llvm.discourse.group/t/evolving-builder-apis-based-on-lessons-learned-from-edsc/879) points out that ValueHandle has become an unnecessary level of abstraction since MLIR switched from `Value *` to `Value` everywhere.

This revision removes this level of indirection.
Commit 367229e100 (parent c2fec2fb17) by Nicolas Vasilache, 2020-04-23 11:00:03 -04:00.
27 changed files with 761 additions and 1155 deletions.
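For illustration, a minimal before/after sketch of the mechanical change (not part of the diff; it mirrors the `regionBuilder` hunk below, with `args` an `ArrayRef<BlockArgument>` inside a `ScopedContext`):

```c++
// Before: every SSA value went through the extra ValueHandle indirection.
ValueHandle h0(args[0]), h1(args[1]);
ValueHandle hsum = std_addf(h0, h1);
// After: plain mlir::Value flows through the same intrinsics.
Value v0 = args[0], v1 = args[1];
Value vsum = std_addf(v0, v1);
```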


@@ -531,9 +531,9 @@ llvm::Optional<SmallVector<AffineMap, 8>> batchmatmul::referenceIndexingMaps() {
void batchmatmul::regionBuilder(ArrayRef<BlockArgument> args) {
using namespace edsc;
using namespace intrinsics;
ValueHandle _0(args[0]), _1(args[1]), _2(args[2]);
ValueHandle _4 = std_mulf(_0, _1);
ValueHandle _5 = std_addf(_2, _4);
Value _0(args[0]), _1(args[1]), _2(args[2]);
Value _4 = std_mulf(_0, _1);
Value _5 = std_addf(_2, _4);
(linalg_yield(ValueRange{ _5 }));
}
```


@@ -13,30 +13,17 @@ case, in C++.
supporting a simple declarative API with globally accessible builders. These
declarative builders are available within the lifetime of a `ScopedContext`.
## ValueHandle and IndexHandle
`mlir::edsc::ValueHandle` and `mlir::edsc::IndexHandle` provide typed
abstractions around an `mlir::Value`. These abstractions are "delayed", in the
sense that they allow separating declaration from definition. They may capture
IR snippets, as they are built, for programmatic manipulation. Intuitive
operators are provided to allow concise and idiomatic expressions.
```c++
ValueHandle zero = std_constant_index(0);
IndexHandle i, j, k;
```
## Intrinsics
`mlir::edsc::ValueBuilder` is a generic wrapper for the `mlir::Builder::create`
method that operates on `ValueHandle` objects and returns a single `ValueHandle`.
For instructions that return no values or that return multiple values, the
`mlir::edsc::InstructionBuilder` can be used. Named intrinsics are provided as
`mlir::ValueBuilder` is a generic wrapper for the `mlir::OpBuilder::create`
method that operates on `Value` objects and returns a single `Value`. For
instructions that return no values or that return multiple values, the
`mlir::edsc::OperationBuilder` can be used. Named intrinsics are provided as
syntactic sugar to further reduce boilerplate.
```c++
using load = ValueBuilder<LoadOp>;
using store = InstructionBuilder<StoreOp>;
using store = OperationBuilder<StoreOp>;
```
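A hypothetical use of these intrinsics inside a `ScopedContext` (the names `buf`, `other`, and `ivs` are assumptions):

```c++
Value x = load(buf, ivs);   // single-result op: yields a Value directly
store(x, other, ivs);       // zero-result op: handled by OperationBuilder
```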
## LoopBuilder and AffineLoopNestBuilder
@@ -46,14 +33,11 @@ concise and structured loop nests.
```c++
ScopedContext scope(f.get());
ValueHandle i(indexType),
j(indexType),
lb(f->getArgument(0)),
ub(f->getArgument(1));
ValueHandle f7(std_constant_float(llvm::APFloat(7.0f), f32Type)),
f13(std_constant_float(llvm::APFloat(13.0f), f32Type)),
i7(constant_int(7, 32)),
i13(constant_int(13, 32));
Value i, j, lb(f->getArgument(0)), ub(f->getArgument(1));
Value f7(std_constant_float(llvm::APFloat(7.0f), f32Type)),
f13(std_constant_float(llvm::APFloat(13.0f), f32Type)),
i7(constant_int(7, 32)),
i13(constant_int(13, 32));
AffineLoopNestBuilder(&i, lb, ub, 3)([&]{
lb * index_type(3) + ub;
lb + index_type(3);
@@ -84,11 +68,10 @@ def AddOp : Op<"x.add">,
Arguments<(ins Tensor:$A, Tensor:$B)>,
Results<(outs Tensor: $C)> {
code referenceImplementation = [{
auto ivs = makeIndexHandles(view_A.rank());
auto pivs = makePIndexHandles(ivs);
SmallVector<Value, 4> ivs(view_A.rank());
IndexedValue A(arg_A), B(arg_B), C(arg_C);
AffineLoopNestBuilder(pivs, view_A.getLbs(), view_A.getUbs(), view_A.getSteps())(
[&]{
AffineLoopNestBuilder(
ivs, view_A.getLbs(), view_A.getUbs(), view_A.getSteps())([&]{
C(ivs) = A(ivs) + B(ivs)
});
}];
@@ -124,10 +107,4 @@ Similar APIs are provided to emit the lower-level `loop.for` op with
`LoopNestBuilder`. See the `builder-api-test.cpp` test for more usage examples.
Since the implementation of declarative builders is in C++, it is also available
to program the IR with an embedded-DSL flavor directly integrated in MLIR. We
make use of these properties in the tutorial.
Spoiler: MLIR also provides Python bindings for these builders, and a
full-fledged Python machine learning DSL with automatic differentiation
targeting MLIR was built as an early research collaboration.
to program the IR with an embedded-DSL flavor directly integrated in MLIR.
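As a sketch, the post-`ValueHandle` idiom for the multi-loop builder reads as follows (with `lbs`, `ubs`, and `steps` assumed in scope):

```c++
ScopedContext scope(f.get());            // as in the example above
SmallVector<Value, 4> ivs(3);            // written through by the builder
AffineLoopNestBuilder(ivs, lbs, ubs, steps)([&] {
  // Loop body: emitted with ivs bound to the induction variables.
});
```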


@@ -23,12 +23,10 @@ namespace mlir {
namespace edsc {
/// Constructs a new AffineForOp and captures the associated induction
/// variable. A ValueHandle pointer is passed as the first argument and is the
/// variable. A Value pointer is passed as the first argument and is the
/// *only* way to capture the loop induction variable.
LoopBuilder makeAffineLoopBuilder(ValueHandle *iv,
ArrayRef<ValueHandle> lbHandles,
ArrayRef<ValueHandle> ubHandles,
int64_t step);
LoopBuilder makeAffineLoopBuilder(Value *iv, ArrayRef<Value> lbs,
ArrayRef<Value> ubs, int64_t step);
/// Explicit nested LoopBuilder. Offers a compressed multi-loop builder to avoid
/// explicitly writing all the loops in a nest. This simple functionality is
@@ -58,10 +56,10 @@ public:
/// This entry point accommodates the fact that AffineForOp implicitly uses
/// multiple `lbs` and `ubs` with one single `iv` and `step` to encode `max`
/// and `min` constraints, respectively.
AffineLoopNestBuilder(ValueHandle *iv, ArrayRef<ValueHandle> lbs,
ArrayRef<ValueHandle> ubs, int64_t step);
AffineLoopNestBuilder(ArrayRef<ValueHandle *> ivs, ArrayRef<ValueHandle> lbs,
ArrayRef<ValueHandle> ubs, ArrayRef<int64_t> steps);
AffineLoopNestBuilder(Value *iv, ArrayRef<Value> lbs, ArrayRef<Value> ubs,
int64_t step);
AffineLoopNestBuilder(MutableArrayRef<Value> ivs, ArrayRef<Value> lbs,
ArrayRef<Value> ubs, ArrayRef<int64_t> steps);
void operator()(function_ref<void(void)> fun = nullptr);
@@ -71,133 +69,134 @@ private:
namespace op {
ValueHandle operator+(ValueHandle lhs, ValueHandle rhs);
ValueHandle operator-(ValueHandle lhs, ValueHandle rhs);
ValueHandle operator*(ValueHandle lhs, ValueHandle rhs);
ValueHandle operator/(ValueHandle lhs, ValueHandle rhs);
ValueHandle operator%(ValueHandle lhs, ValueHandle rhs);
ValueHandle floorDiv(ValueHandle lhs, ValueHandle rhs);
ValueHandle ceilDiv(ValueHandle lhs, ValueHandle rhs);
Value operator+(Value lhs, Value rhs);
Value operator-(Value lhs, Value rhs);
Value operator*(Value lhs, Value rhs);
Value operator/(Value lhs, Value rhs);
Value operator%(Value lhs, Value rhs);
Value floorDiv(Value lhs, Value rhs);
Value ceilDiv(Value lhs, Value rhs);
ValueHandle operator!(ValueHandle value);
ValueHandle operator&&(ValueHandle lhs, ValueHandle rhs);
ValueHandle operator||(ValueHandle lhs, ValueHandle rhs);
ValueHandle operator^(ValueHandle lhs, ValueHandle rhs);
ValueHandle operator==(ValueHandle lhs, ValueHandle rhs);
ValueHandle operator!=(ValueHandle lhs, ValueHandle rhs);
ValueHandle operator<(ValueHandle lhs, ValueHandle rhs);
ValueHandle operator<=(ValueHandle lhs, ValueHandle rhs);
ValueHandle operator>(ValueHandle lhs, ValueHandle rhs);
ValueHandle operator>=(ValueHandle lhs, ValueHandle rhs);
/// Logical operator overloadings.
Value negate(Value value);
Value operator&&(Value lhs, Value rhs);
Value operator||(Value lhs, Value rhs);
Value operator^(Value lhs, Value rhs);
/// Comparison operator overloadings.
Value eq(Value lhs, Value rhs);
Value ne(Value lhs, Value rhs);
Value operator<(Value lhs, Value rhs);
Value operator<=(Value lhs, Value rhs);
Value operator>(Value lhs, Value rhs);
Value operator>=(Value lhs, Value rhs);
} // namespace op
/// Arithmetic operator overloadings.
template <typename Load, typename Store>
ValueHandle TemplatedIndexedValue<Load, Store>::operator+(ValueHandle e) {
Value TemplatedIndexedValue<Load, Store>::operator+(Value e) {
using op::operator+;
return static_cast<ValueHandle>(*this) + e;
return static_cast<Value>(*this) + e;
}
template <typename Load, typename Store>
ValueHandle TemplatedIndexedValue<Load, Store>::operator-(ValueHandle e) {
Value TemplatedIndexedValue<Load, Store>::operator-(Value e) {
using op::operator-;
return static_cast<ValueHandle>(*this) - e;
return static_cast<Value>(*this) - e;
}
template <typename Load, typename Store>
ValueHandle TemplatedIndexedValue<Load, Store>::operator*(ValueHandle e) {
Value TemplatedIndexedValue<Load, Store>::operator*(Value e) {
using op::operator*;
return static_cast<ValueHandle>(*this) * e;
return static_cast<Value>(*this) * e;
}
template <typename Load, typename Store>
ValueHandle TemplatedIndexedValue<Load, Store>::operator/(ValueHandle e) {
Value TemplatedIndexedValue<Load, Store>::operator/(Value e) {
using op::operator/;
return static_cast<ValueHandle>(*this) / e;
return static_cast<Value>(*this) / e;
}
template <typename Load, typename Store>
ValueHandle TemplatedIndexedValue<Load, Store>::operator%(ValueHandle e) {
Value TemplatedIndexedValue<Load, Store>::operator%(Value e) {
using op::operator%;
return static_cast<ValueHandle>(*this) % e;
return static_cast<Value>(*this) % e;
}
template <typename Load, typename Store>
ValueHandle TemplatedIndexedValue<Load, Store>::operator^(ValueHandle e) {
Value TemplatedIndexedValue<Load, Store>::operator^(Value e) {
using op::operator^;
return static_cast<ValueHandle>(*this) ^ e;
return static_cast<Value>(*this) ^ e;
}
/// Assignment-arithmetic operator overloadings.
template <typename Load, typename Store>
OperationHandle TemplatedIndexedValue<Load, Store>::operator+=(ValueHandle e) {
OperationHandle TemplatedIndexedValue<Load, Store>::operator+=(Value e) {
using op::operator+;
return Store(*this + e, getBase(), {indices.begin(), indices.end()});
}
template <typename Load, typename Store>
OperationHandle TemplatedIndexedValue<Load, Store>::operator-=(ValueHandle e) {
OperationHandle TemplatedIndexedValue<Load, Store>::operator-=(Value e) {
using op::operator-;
return Store(*this - e, getBase(), {indices.begin(), indices.end()});
}
template <typename Load, typename Store>
OperationHandle TemplatedIndexedValue<Load, Store>::operator*=(ValueHandle e) {
OperationHandle TemplatedIndexedValue<Load, Store>::operator*=(Value e) {
using op::operator*;
return Store(*this * e, getBase(), {indices.begin(), indices.end()});
}
template <typename Load, typename Store>
OperationHandle TemplatedIndexedValue<Load, Store>::operator/=(ValueHandle e) {
OperationHandle TemplatedIndexedValue<Load, Store>::operator/=(Value e) {
using op::operator/;
return Store(*this / e, getBase(), {indices.begin(), indices.end()});
}
template <typename Load, typename Store>
OperationHandle TemplatedIndexedValue<Load, Store>::operator%=(ValueHandle e) {
OperationHandle TemplatedIndexedValue<Load, Store>::operator%=(Value e) {
using op::operator%;
return Store(*this % e, getBase(), {indices.begin(), indices.end()});
}
template <typename Load, typename Store>
OperationHandle TemplatedIndexedValue<Load, Store>::operator^=(ValueHandle e) {
OperationHandle TemplatedIndexedValue<Load, Store>::operator^=(Value e) {
using op::operator^;
return Store(*this ^ e, getBase(), {indices.begin(), indices.end()});
}
/// Logical operator overloadings.
template <typename Load, typename Store>
ValueHandle TemplatedIndexedValue<Load, Store>::operator&&(ValueHandle e) {
Value TemplatedIndexedValue<Load, Store>::operator&&(Value e) {
using op::operator&&;
return static_cast<ValueHandle>(*this) && e;
return static_cast<Value>(*this) && e;
}
template <typename Load, typename Store>
ValueHandle TemplatedIndexedValue<Load, Store>::operator||(ValueHandle e) {
Value TemplatedIndexedValue<Load, Store>::operator||(Value e) {
using op::operator||;
return static_cast<ValueHandle>(*this) || e;
return static_cast<Value>(*this) || e;
}
/// Comparison operator overloadings.
template <typename Load, typename Store>
ValueHandle TemplatedIndexedValue<Load, Store>::operator==(ValueHandle e) {
using op::operator==;
return static_cast<ValueHandle>(*this) == e;
Value TemplatedIndexedValue<Load, Store>::eq(Value e) {
return eq(value, e);
}
template <typename Load, typename Store>
ValueHandle TemplatedIndexedValue<Load, Store>::operator!=(ValueHandle e) {
using op::operator!=;
return static_cast<ValueHandle>(*this) != e;
Value TemplatedIndexedValue<Load, Store>::ne(Value e) {
return ne(value, e);
}
template <typename Load, typename Store>
ValueHandle TemplatedIndexedValue<Load, Store>::operator<(ValueHandle e) {
Value TemplatedIndexedValue<Load, Store>::operator<(Value e) {
using op::operator<;
return static_cast<ValueHandle>(*this) < e;
return static_cast<Value>(*this) < e;
}
template <typename Load, typename Store>
ValueHandle TemplatedIndexedValue<Load, Store>::operator<=(ValueHandle e) {
Value TemplatedIndexedValue<Load, Store>::operator<=(Value e) {
using op::operator<=;
return static_cast<ValueHandle>(*this) <= e;
return static_cast<Value>(*this) <= e;
}
template <typename Load, typename Store>
ValueHandle TemplatedIndexedValue<Load, Store>::operator>(ValueHandle e) {
Value TemplatedIndexedValue<Load, Store>::operator>(Value e) {
using op::operator>;
return static_cast<ValueHandle>(*this) > e;
return static_cast<Value>(*this) > e;
}
template <typename Load, typename Store>
ValueHandle TemplatedIndexedValue<Load, Store>::operator>=(ValueHandle e) {
Value TemplatedIndexedValue<Load, Store>::operator>=(Value e) {
using op::operator>=;
return static_cast<ValueHandle>(*this) >= e;
return static_cast<Value>(*this) >= e;
}
} // namespace edsc
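With these overloads, the arithmetic sugar now applies directly to `mlir::Value`. A sketch, assuming `a` and `b` are index-typed `Value`s inside a `ScopedContext`:

```c++
using namespace mlir::edsc::op;
Value c = a + b * a;   // emits the corresponding affine/std arithmetic ops
Value lt = c < b;      // comparison helpers yield i1 Values
```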


@@ -42,11 +42,10 @@ class ParallelLoopNestBuilder;
class LoopRangeBuilder : public NestedBuilder {
public:
/// Constructs a new loop.for and captures the associated induction
/// variable. A ValueHandle pointer is passed as the first argument and is the
/// variable. A Value pointer is passed as the first argument and is the
/// *only* way to capture the loop induction variable.
LoopRangeBuilder(ValueHandle *iv, ValueHandle range);
LoopRangeBuilder(ValueHandle *iv, Value range);
LoopRangeBuilder(ValueHandle *iv, SubViewOp::Range range);
LoopRangeBuilder(Value *iv, Value range);
LoopRangeBuilder(Value *iv, SubViewOp::Range range);
LoopRangeBuilder(const LoopRangeBuilder &) = delete;
LoopRangeBuilder(LoopRangeBuilder &&) = default;
@@ -57,7 +56,7 @@ public:
/// The only purpose of this operator is to serve as a sequence point so that
/// the evaluation of `fun` (which builds IR snippets in a scoped fashion) is
/// scoped within a LoopRangeBuilder.
ValueHandle operator()(std::function<void(void)> fun = nullptr);
Value operator()(std::function<void(void)> fun = nullptr);
};
/// Helper class to sugar building loop.for loop nests from ranges.
@@ -65,13 +64,10 @@ public:
/// directly. In the current implementation it produces loop.for operations.
class LoopNestRangeBuilder {
public:
LoopNestRangeBuilder(ArrayRef<edsc::ValueHandle *> ivs,
ArrayRef<edsc::ValueHandle> ranges);
LoopNestRangeBuilder(ArrayRef<edsc::ValueHandle *> ivs,
ArrayRef<Value> ranges);
LoopNestRangeBuilder(ArrayRef<edsc::ValueHandle *> ivs,
LoopNestRangeBuilder(MutableArrayRef<Value> ivs, ArrayRef<Value> ranges);
LoopNestRangeBuilder(MutableArrayRef<Value> ivs,
ArrayRef<SubViewOp::Range> ranges);
edsc::ValueHandle operator()(std::function<void(void)> fun = nullptr);
Value operator()(std::function<void(void)> fun = nullptr);
private:
SmallVector<LoopRangeBuilder, 4> loops;
@@ -81,7 +77,7 @@ private:
/// ranges.
template <typename LoopTy> class GenericLoopNestRangeBuilder {
public:
GenericLoopNestRangeBuilder(ArrayRef<edsc::ValueHandle *> ivs,
GenericLoopNestRangeBuilder(MutableArrayRef<Value> ivs,
ArrayRef<Value> ranges);
void operator()(std::function<void(void)> fun = nullptr) { (*builder)(fun); }
@@ -124,7 +120,6 @@ Operation *makeGenericLinalgOp(
namespace ops {
using edsc::StructuredIndexed;
using edsc::ValueHandle;
//===----------------------------------------------------------------------===//
// EDSC builders for linalg generic operations.
@@ -160,7 +155,7 @@ void macRegionBuilder(ArrayRef<BlockArgument> args);
/// with in-place semantics and parallelism.
/// Unary pointwise operation (with broadcast) entry point.
using UnaryPointwiseOpBuilder = function_ref<Value(ValueHandle)>;
using UnaryPointwiseOpBuilder = function_ref<Value(Value)>;
Operation *linalg_generic_pointwise(UnaryPointwiseOpBuilder unaryOp,
StructuredIndexed I, StructuredIndexed O);
@@ -171,7 +166,7 @@ Operation *linalg_generic_pointwise_tanh(StructuredIndexed I,
StructuredIndexed O);
/// Binary pointwise operation (with broadcast) entry point.
using BinaryPointwiseOpBuilder = function_ref<Value(ValueHandle, ValueHandle)>;
using BinaryPointwiseOpBuilder = function_ref<Value(Value, Value)>;
Operation *linalg_generic_pointwise(BinaryPointwiseOpBuilder binaryOp,
StructuredIndexed I1, StructuredIndexed I2,
StructuredIndexed O);
@@ -202,7 +197,7 @@ using MatmulRegionBuilder = function_ref<void(ArrayRef<BlockArgument> args)>;
/// | C(m, n) += A(m, k) * B(k, n)
/// ```
Operation *
linalg_generic_matmul(ValueHandle vA, ValueHandle vB, ValueHandle vC,
linalg_generic_matmul(Value vA, Value vB, Value vC,
MatmulRegionBuilder regionBuilder = macRegionBuilder);
/// Build a linalg.generic, under the current ScopedContext, at the current
@@ -214,7 +209,7 @@ linalg_generic_matmul(ValueHandle vA, ValueHandle vB, ValueHandle vC,
/// ```
/// and returns the tensor `C`.
Operation *
linalg_generic_matmul(ValueHandle vA, ValueHandle vB, RankedTensorType tC,
linalg_generic_matmul(Value vA, Value vB, RankedTensorType tC,
MatmulRegionBuilder regionBuilder = mulRegionBuilder);
/// Build a linalg.generic, under the current ScopedContext, at the current
@@ -226,8 +221,7 @@ linalg_generic_matmul(ValueHandle vA, ValueHandle vB, RankedTensorType tC,
/// ```
/// and returns the tensor `D`.
Operation *
linalg_generic_matmul(ValueHandle vA, ValueHandle vB, ValueHandle vC,
RankedTensorType tD,
linalg_generic_matmul(Value vA, Value vB, Value vC, RankedTensorType tD,
MatmulRegionBuilder regionBuilder = macRegionBuilder);
template <typename Container>
@@ -260,8 +254,8 @@ linalg_generic_matmul(Container values,
/// For now `...` must be empty (i.e. only 2-D convolutions are supported).
///
// TODO(ntv) Extend convolution rank with some template magic.
Operation *linalg_generic_conv_nhwc(ValueHandle vI, ValueHandle vW,
ValueHandle vO, ArrayRef<int> strides = {},
Operation *linalg_generic_conv_nhwc(Value vI, Value vW, Value vO,
ArrayRef<int> strides = {},
ArrayRef<int> dilations = {});
template <typename Container>
@@ -295,8 +289,7 @@ Operation *linalg_generic_conv_nhwc(Container values,
/// For now `...` must be empty (i.e. only 2-D convolutions are supported).
///
// TODO(ntv) Extend convolution rank with some template magic.
Operation *linalg_generic_dilated_conv_nhwc(ValueHandle vI, ValueHandle vW,
ValueHandle vO,
Operation *linalg_generic_dilated_conv_nhwc(Value vI, Value vW, Value vO,
int depth_multiplier = 1,
ArrayRef<int> strides = {},
ArrayRef<int> dilations = {});
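Taken together, these entry points make common structured ops one-liners; a sketch with operand `Value`s assumed:

```c++
linalg_generic_matmul(vA, vB, vC);                       // C(m, n) += A(m, k) * B(k, n)
linalg_generic_conv_nhwc(vI, vW, vO, /*strides=*/{2, 2}); // 2-D NHWC convolution
```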


@@ -15,16 +15,52 @@
namespace mlir {
namespace edsc {
namespace intrinsics {
template <typename Op, typename... Args>
ValueHandle ValueHandle::create(OperationFolder *folder, Args... args) {
return folder ? ValueHandle(folder->create<Op>(ScopedContext::getBuilder(),
ScopedContext::getLocation(),
args...))
: ValueHandle(ScopedContext::getBuilder().create<Op>(
ScopedContext::getLocation(), args...));
}
template <typename Op>
struct FoldedValueBuilder {
// Builder-based
template <typename... Args>
FoldedValueBuilder(OperationFolder *folder, Args... args) {
value = folder ? folder->create<Op>(ScopedContext::getBuilder(),
ScopedContext::getLocation(), args...)
: ScopedContext::getBuilder().create<Op>(
ScopedContext::getLocation(), args...);
}
operator Value() { return value; }
Value value;
};
using folded_std_constant_index = FoldedValueBuilder<ConstantIndexOp>;
using folded_std_constant_float = FoldedValueBuilder<ConstantFloatOp>;
using folded_std_constant_int = FoldedValueBuilder<ConstantIntOp>;
using folded_std_constant = FoldedValueBuilder<ConstantOp>;
using folded_std_dim = FoldedValueBuilder<DimOp>;
using folded_std_muli = FoldedValueBuilder<MulIOp>;
using folded_std_addi = FoldedValueBuilder<AddIOp>;
using folded_std_addf = FoldedValueBuilder<AddFOp>;
using folded_std_alloc = FoldedValueBuilder<AllocOp>;
using folded_std_constant = FoldedValueBuilder<ConstantOp>;
using folded_std_constant_float = FoldedValueBuilder<ConstantFloatOp>;
using folded_std_constant_index = FoldedValueBuilder<ConstantIndexOp>;
using folded_std_constant_int = FoldedValueBuilder<ConstantIntOp>;
using folded_std_dim = FoldedValueBuilder<DimOp>;
using folded_std_extract_element = FoldedValueBuilder<ExtractElementOp>;
using folded_std_index_cast = FoldedValueBuilder<IndexCastOp>;
using folded_std_muli = FoldedValueBuilder<MulIOp>;
using folded_std_mulf = FoldedValueBuilder<MulFOp>;
using folded_std_memref_cast = FoldedValueBuilder<MemRefCastOp>;
using folded_std_select = FoldedValueBuilder<SelectOp>;
using folded_std_load = FoldedValueBuilder<LoadOp>;
using folded_std_subi = FoldedValueBuilder<SubIOp>;
using folded_std_sub_view = FoldedValueBuilder<SubViewOp>;
using folded_std_tanh = FoldedValueBuilder<TanhOp>;
using folded_std_tensor_load = FoldedValueBuilder<TensorLoadOp>;
using folded_std_view = FoldedValueBuilder<ViewOp>;
using folded_std_zero_extendi = FoldedValueBuilder<ZeroExtendIOp>;
using folded_std_sign_extendi = FoldedValueBuilder<SignExtendIOp>;
} // namespace intrinsics
} // namespace edsc
} // namespace mlir
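A sketch of how a folded intrinsic is meant to be used (`ctx` is an assumed `MLIRContext *`; the folder plumbing itself is unchanged by this commit):

```c++
OperationFolder folder(ctx);
// Goes through the folder, which may constant-fold instead of emitting an op.
Value zero = folded_std_constant_index(&folder, 0);
```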


@@ -23,27 +23,30 @@ namespace mlir {
namespace edsc {
/// Constructs a new loop::ParallelOp and captures the associated induction
/// variables. An array of ValueHandle pointers is passed as the first
/// variables. An array of Value pointers is passed as the first
/// argument and is the *only* way to capture loop induction variables.
LoopBuilder makeParallelLoopBuilder(ArrayRef<ValueHandle *> ivs,
ArrayRef<ValueHandle> lbHandles,
ArrayRef<ValueHandle> ubHandles,
ArrayRef<ValueHandle> steps);
LoopBuilder makeParallelLoopBuilder(MutableArrayRef<Value> ivs,
ArrayRef<Value> lbs, ArrayRef<Value> ubs,
ArrayRef<Value> steps);
/// Constructs a new loop::ForOp and captures the associated induction
/// variable. A ValueHandle pointer is passed as the first argument and is the
/// variable. A Value pointer is passed as the first argument and is the
/// *only* way to capture the loop induction variable.
LoopBuilder makeLoopBuilder(ValueHandle *iv, ValueHandle lbHandle,
ValueHandle ubHandle, ValueHandle stepHandle,
ArrayRef<ValueHandle *> iter_args_handles = {},
ValueRange iter_args_init_values = {});
LoopBuilder makeLoopBuilder(Value *iv, Value lb, Value ub, Value step,
MutableArrayRef<Value> iterArgsHandles,
ValueRange iterArgsInitValues);
LoopBuilder makeLoopBuilder(Value *iv, Value lb, Value ub, Value step,
MutableArrayRef<Value> iterArgsHandles,
ValueRange iterArgsInitValues);
inline LoopBuilder makeLoopBuilder(Value *iv, Value lb, Value ub, Value step) {
return makeLoopBuilder(iv, lb, ub, step, MutableArrayRef<Value>{}, {});
}
/// Helper class to sugar building loop.parallel loop nests from lower/upper
/// bounds and step sizes.
class ParallelLoopNestBuilder {
public:
ParallelLoopNestBuilder(ArrayRef<ValueHandle *> ivs,
ArrayRef<ValueHandle> lbs, ArrayRef<ValueHandle> ubs,
ArrayRef<ValueHandle> steps);
ParallelLoopNestBuilder(MutableArrayRef<Value> ivs, ArrayRef<Value> lbs,
ArrayRef<Value> ubs, ArrayRef<Value> steps);
void operator()(function_ref<void(void)> fun = nullptr);
@@ -56,12 +59,12 @@ private:
/// loop.for.
class LoopNestBuilder {
public:
LoopNestBuilder(ValueHandle *iv, ValueHandle lb, ValueHandle ub,
ValueHandle step,
ArrayRef<ValueHandle *> iter_args_handles = {},
ValueRange iter_args_init_values = {});
LoopNestBuilder(ArrayRef<ValueHandle *> ivs, ArrayRef<ValueHandle> lbs,
ArrayRef<ValueHandle> ubs, ArrayRef<ValueHandle> steps);
LoopNestBuilder(Value *iv, Value lb, Value ub, Value step);
LoopNestBuilder(Value *iv, Value lb, Value ub, Value step,
MutableArrayRef<Value> iterArgsHandles,
ValueRange iterArgsInitValues);
LoopNestBuilder(MutableArrayRef<Value> ivs, ArrayRef<Value> lbs,
ArrayRef<Value> ubs, ArrayRef<Value> steps);
Operation::result_range operator()(std::function<void(void)> fun = nullptr);
private:
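A sketch of capturing a `loop.for` through the new `Value *` out-parameter (assumed to run inside a `ScopedContext`, with `lb`, `ub`, and `step` in scope):

```c++
Value iv;
LoopNestBuilder(&iv, lb, ub, step)([&] {
  // Single loop.for; `iv` is bound to its induction variable here.
});
```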


@@ -20,27 +20,27 @@ namespace edsc {
class BoundsCapture {
public:
unsigned rank() const { return lbs.size(); }
ValueHandle lb(unsigned idx) { return lbs[idx]; }
ValueHandle ub(unsigned idx) { return ubs[idx]; }
Value lb(unsigned idx) { return lbs[idx]; }
Value ub(unsigned idx) { return ubs[idx]; }
int64_t step(unsigned idx) { return steps[idx]; }
std::tuple<ValueHandle, ValueHandle, int64_t> range(unsigned idx) {
std::tuple<Value, Value, int64_t> range(unsigned idx) {
return std::make_tuple(lbs[idx], ubs[idx], steps[idx]);
}
void swapRanges(unsigned i, unsigned j) {
if (i == j)
return;
lbs[i].swap(lbs[j]);
ubs[i].swap(ubs[j]);
std::swap(lbs[i], lbs[j]);
std::swap(ubs[i], ubs[j]);
std::swap(steps[i], steps[j]);
}
ArrayRef<ValueHandle> getLbs() { return lbs; }
ArrayRef<ValueHandle> getUbs() { return ubs; }
ArrayRef<Value> getLbs() { return lbs; }
ArrayRef<Value> getUbs() { return ubs; }
ArrayRef<int64_t> getSteps() { return steps; }
protected:
SmallVector<ValueHandle, 8> lbs;
SmallVector<ValueHandle, 8> ubs;
SmallVector<Value, 8> lbs;
SmallVector<Value, 8> ubs;
SmallVector<int64_t, 8> steps;
};
@@ -58,7 +58,7 @@ public:
unsigned fastestVarying() const { return rank() - 1; }
private:
ValueHandle base;
Value base;
};
/// A VectorBoundsCapture represents the information required to step through a
@@ -72,7 +72,7 @@ public:
VectorBoundsCapture &operator=(const VectorBoundsCapture &) = default;
private:
ValueHandle base;
Value base;
};
} // namespace edsc
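For orientation, a sketch of querying a capture (`memref` is an assumed `Value` of `MemRefType`):

```c++
MemRefBoundsCapture bounds(memref);
Value lb0 = bounds.lb(0), ub0 = bounds.ub(0);  // per-dimension bounds as Values
int64_t step0 = bounds.step(0);
```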


@@ -14,40 +14,6 @@
namespace mlir {
namespace edsc {
namespace intrinsics {
namespace folded {
/// Helper variadic abstraction to allow extending to any MLIR op without
/// boilerplate or Tablegen.
/// Arguably a builder is not a ValueHandle but in practice it is only used as
/// an alias to a notional ValueHandle<Op>.
/// Implementing it as a subclass allows it to compose all the way to Value.
/// Without subclassing, implicit conversion to Value would fail when composing
/// in patterns such as: `select(a, b, select(c, d, e))`.
template <typename Op>
struct ValueBuilder : public ValueHandle {
/// Folder-based
template <typename... Args>
ValueBuilder(OperationFolder *folder, Args... args)
: ValueHandle(ValueHandle::create<Op>(folder, detail::unpack(args)...)) {}
ValueBuilder(OperationFolder *folder, ArrayRef<ValueHandle> vs)
: ValueBuilder(ValueBuilder::create<Op>(folder, detail::unpack(vs))) {}
template <typename... Args>
ValueBuilder(OperationFolder *folder, ArrayRef<ValueHandle> vs, Args... args)
: ValueHandle(ValueHandle::create<Op>(folder, detail::unpack(vs),
detail::unpack(args)...)) {}
template <typename T, typename... Args>
ValueBuilder(OperationFolder *folder, T t, ArrayRef<ValueHandle> vs,
Args... args)
: ValueHandle(ValueHandle::create<Op>(folder, detail::unpack(t),
detail::unpack(vs),
detail::unpack(args)...)) {}
template <typename T1, typename T2, typename... Args>
ValueBuilder(OperationFolder *folder, T1 t1, T2 t2, ArrayRef<ValueHandle> vs,
Args... args)
: ValueHandle(ValueHandle::create<Op>(
folder, detail::unpack(t1), detail::unpack(t2), detail::unpack(vs),
detail::unpack(args)...)) {}
};
} // namespace folded
using std_addf = ValueBuilder<AddFOp>;
using std_alloc = ValueBuilder<AllocOp>;
@@ -80,7 +46,7 @@ using std_sign_extendi = ValueBuilder<SignExtendIOp>;
///
/// Prerequisites:
/// All Handles have already captured previously constructed IR objects.
OperationHandle std_br(BlockHandle bh, ArrayRef<ValueHandle> operands);
OperationHandle std_br(BlockHandle bh, ArrayRef<Value> operands);
/// Creates a new mlir::Block* and branches to it from the current block.
/// Argument types are specified by `operands`.
@@ -95,8 +61,9 @@ OperationHandle std_br(BlockHandle bh, ArrayRef<ValueHandle> operands);
/// All `operands` have already captured an mlir::Value
/// captures.size() == operands.size()
/// captures and operands are pairwise of the same type.
OperationHandle std_br(BlockHandle *bh, ArrayRef<ValueHandle *> captures,
ArrayRef<ValueHandle> operands);
OperationHandle std_br(BlockHandle *bh, ArrayRef<Type> types,
MutableArrayRef<Value> captures,
ArrayRef<Value> operands);
/// Branches into the mlir::Block* captured by BlockHandle `trueBranch` with
/// `trueOperands` if `cond` evaluates to `true` (resp. `falseBranch` and
@@ -104,10 +71,10 @@ OperationHandle std_br(BlockHandle *bh, ArrayRef<ValueHandle *> captures,
///
/// Prerequisites:
/// All Handles have captured previously constructed IR objects.
OperationHandle std_cond_br(ValueHandle cond, BlockHandle trueBranch,
ArrayRef<ValueHandle> trueOperands,
OperationHandle std_cond_br(Value cond, BlockHandle trueBranch,
ArrayRef<Value> trueOperands,
BlockHandle falseBranch,
ArrayRef<ValueHandle> falseOperands);
ArrayRef<Value> falseOperands);
/// Eagerly creates new mlir::Block* with argument types specified by
/// `trueOperands`/`falseOperands`.
@@ -125,45 +92,17 @@ OperationHandle std_cond_br(ValueHandle cond, BlockHandle trueBranch,
/// `falseCaptures`.size() == `falseOperands`.size()
/// `trueCaptures` and `trueOperands` are pairwise of the same type
/// `falseCaptures` and `falseOperands` are pairwise of the same type.
OperationHandle std_cond_br(ValueHandle cond, BlockHandle *trueBranch,
ArrayRef<ValueHandle *> trueCaptures,
ArrayRef<ValueHandle> trueOperands,
BlockHandle *falseBranch,
ArrayRef<ValueHandle *> falseCaptures,
ArrayRef<ValueHandle> falseOperands);
OperationHandle std_cond_br(Value cond, BlockHandle *trueBranch,
ArrayRef<Type> trueTypes,
MutableArrayRef<Value> trueCaptures,
ArrayRef<Value> trueOperands,
BlockHandle *falseBranch, ArrayRef<Type> falseTypes,
MutableArrayRef<Value> falseCaptures,
ArrayRef<Value> falseOperands);
/// Provide an index notation around std_load and std_store.
using StdIndexedValue =
TemplatedIndexedValue<intrinsics::std_load, intrinsics::std_store>;
using folded_std_constant_index = folded::ValueBuilder<ConstantIndexOp>;
using folded_std_constant_float = folded::ValueBuilder<ConstantFloatOp>;
using folded_std_constant_int = folded::ValueBuilder<ConstantIntOp>;
using folded_std_constant = folded::ValueBuilder<ConstantOp>;
using folded_std_dim = folded::ValueBuilder<DimOp>;
using folded_std_muli = folded::ValueBuilder<MulIOp>;
using folded_std_addi = folded::ValueBuilder<AddIOp>;
using folded_std_addf = folded::ValueBuilder<AddFOp>;
using folded_std_alloc = folded::ValueBuilder<AllocOp>;
using folded_std_constant = folded::ValueBuilder<ConstantOp>;
using folded_std_constant_float = folded::ValueBuilder<ConstantFloatOp>;
using folded_std_constant_index = folded::ValueBuilder<ConstantIndexOp>;
using folded_std_constant_int = folded::ValueBuilder<ConstantIntOp>;
using folded_std_dim = folded::ValueBuilder<DimOp>;
using folded_std_extract_element = folded::ValueBuilder<ExtractElementOp>;
using folded_std_index_cast = folded::ValueBuilder<IndexCastOp>;
using folded_std_muli = folded::ValueBuilder<MulIOp>;
using folded_std_mulf = folded::ValueBuilder<MulFOp>;
using folded_std_memref_cast = folded::ValueBuilder<MemRefCastOp>;
using folded_std_select = folded::ValueBuilder<SelectOp>;
using folded_std_load = folded::ValueBuilder<LoadOp>;
using folded_std_subi = folded::ValueBuilder<SubIOp>;
using folded_std_sub_view = folded::ValueBuilder<SubViewOp>;
using folded_std_tanh = folded::ValueBuilder<TanhOp>;
using folded_std_tensor_load = folded::ValueBuilder<TensorLoadOp>;
using folded_std_view = folded::ValueBuilder<ViewOp>;
using folded_std_zero_extendi = folded::ValueBuilder<ZeroExtendIOp>;
using folded_std_sign_extendi = folded::ValueBuilder<SignExtendIOp>;
} // namespace intrinsics
} // namespace edsc
} // namespace mlir
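A sketch of the new capture scheme: block-argument types travel in the `ArrayRef<Type>`, while the `Value` captures are written through in place (`indexType` and `iv` are assumptions):

```c++
BlockHandle target;
Value blockArg;  // bound to the new block's argument by the call
std_br(&target, {indexType}, MutableArrayRef<Value>(blockArg), {iv});
```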


@@ -18,6 +18,7 @@ using vector_broadcast = ValueBuilder<vector::BroadcastOp>;
using vector_contract = ValueBuilder<vector::ContractionOp>;
using vector_matmul = ValueBuilder<vector::MatmulOp>;
using vector_print = OperationBuilder<vector::PrintOp>;
using vector_type_cast = ValueBuilder<vector::TypeCastOp>;
} // namespace intrinsics
} // namespace edsc


@@ -24,9 +24,7 @@ class OperationFolder;
namespace edsc {
class BlockHandle;
class CapturableHandle;
class NestedBuilder;
class ValueHandle;
/// Helper class to transparently handle builder insertion points by RAII.
/// As its name indicates, a ScopedContext is meant to be used locally in a
@@ -70,10 +68,23 @@ private:
/// Defensively keeps track of the current NestedBuilder to ensure proper
/// scoping usage.
NestedBuilder *nestedBuilder;
};
// TODO: Implement scoping of ValueHandles. To do this we need a proper data
// structure to hold ValueHandle objects. We can emulate one but there should
// already be something available in LLVM for this purpose.
template <typename Op>
struct ValueBuilder {
// Builder-based
template <typename... Args>
ValueBuilder(Args... args) {
Operation *op = ScopedContext::getBuilder()
.create<Op>(ScopedContext::getLocation(), args...)
.getOperation();
if (op->getNumResults() != 1)
llvm_unreachable("unsupported operation, use OperationBuilder instead");
value = op->getResult(0);
}
operator Value() { return value; }
Value value;
};
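Any single-result op thus becomes a one-expression intrinsic; a sketch (the guard above fires for multi-result ops, which should use `OperationBuilder`):

```c++
using std_addf = ValueBuilder<AddFOp>;
Value sum = std_addf(lhs, rhs);  // created at the ScopedContext insertion point
```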
/// A NestedBuilder is a scoping abstraction to create an idiomatic syntax
@@ -82,8 +93,7 @@ private:
/// exists between object construction and method invocation on said object (in
/// our case, the call to `operator()`).
/// This ordering allows implementing an abstraction that decouples definition
/// from declaration (in a PL sense) on placeholders of type ValueHandle and
/// BlockHandle.
/// from declaration (in a PL sense) on placeholders.
class NestedBuilder {
protected:
NestedBuilder() = default;
@@ -158,19 +168,17 @@ public:
private:
LoopBuilder() = default;
friend LoopBuilder makeAffineLoopBuilder(ValueHandle *iv,
ArrayRef<ValueHandle> lbHandles,
ArrayRef<ValueHandle> ubHandles,
friend LoopBuilder makeAffineLoopBuilder(Value *iv, ArrayRef<Value> lbHandles,
ArrayRef<Value> ubHandles,
int64_t step);
friend LoopBuilder makeParallelLoopBuilder(ArrayRef<ValueHandle *> ivs,
ArrayRef<ValueHandle> lbHandles,
ArrayRef<ValueHandle> ubHandles,
ArrayRef<ValueHandle> steps);
friend LoopBuilder makeLoopBuilder(ValueHandle *iv, ValueHandle lbHandle,
ValueHandle ubHandle,
ValueHandle stepHandle,
ArrayRef<ValueHandle *> iter_args_handles,
ValueRange iter_args_init_values);
friend LoopBuilder makeParallelLoopBuilder(MutableArrayRef<Value> ivs,
ArrayRef<Value> lbHandles,
ArrayRef<Value> ubHandles,
ArrayRef<Value> steps);
friend LoopBuilder makeLoopBuilder(Value *iv, Value lbHandle, Value ubHandle,
Value stepHandle,
MutableArrayRef<Value> iterArgsHandles,
ValueRange iterArgsInitValues);
Operation *op;
};
@@ -194,9 +202,11 @@ public:
/// Enters the new mlir::Block* and sets the insertion point to its end.
///
/// Prerequisites:
/// The ValueHandle `args` are typed delayed ValueHandles; i.e. they are
/// The Value `args` are typed delayed Values; i.e. they are
/// not yet bound to mlir::Value.
BlockBuilder(BlockHandle *bh, ArrayRef<ValueHandle *> args);
BlockBuilder(BlockHandle *bh) : BlockBuilder(bh, {}, {}) {}
BlockBuilder(BlockHandle *bh, ArrayRef<Type> types,
MutableArrayRef<Value> args);
/// Constructs a new mlir::Block with argument types derived from `args` and
/// appends it as the last block in the region.
@@ -204,9 +214,10 @@ public:
/// Enters the new mlir::Block* and sets the insertion point to its end.
///
/// Prerequisites:
/// The ValueHandle `args` are typed delayed ValueHandles; i.e. they are
/// The Value `args` are typed delayed Values; i.e. they are
/// not yet bound to mlir::Value.
BlockBuilder(BlockHandle *bh, Region &region, ArrayRef<ValueHandle *> args);
BlockBuilder(BlockHandle *bh, Region &region, ArrayRef<Type> types,
MutableArrayRef<Value> args);
/// The only purpose of this operator is to serve as a sequence point so that
/// the evaluation of `fun` (which builds IR snippets in a scoped fashion) is
@@ -218,120 +229,18 @@ private:
BlockBuilder &operator=(BlockBuilder &other) = delete;
};
/// Base class for ValueHandle, OperationHandle and BlockHandle.
/// Base class for Value, OperationHandle and BlockHandle.
/// Not meant to be used outside of these classes.
class CapturableHandle {
protected:
CapturableHandle() = default;
};
/// ValueHandle implements a (potentially "delayed") typed Value abstraction.
/// ValueHandle should be captured by pointer but otherwise passed by Value
/// everywhere.
/// A ValueHandle can have 3 states:
/// 1. null state (empty type and empty value), in which case it does not hold
/// a value and must never hold a Value (now or in the future). This is
/// used for MLIR operations with zero returns as well as the result of
/// calling a NestedBuilder::operator(). In both cases the objective is to
/// have an object that can be inserted in an ArrayRef<ValueHandle> to
/// implement nesting;
/// 2. delayed state (empty value), in which case it represents an eagerly
/// typed "delayed" value that can hold a Value in the future;
/// 3. constructed state, in which case it holds a Value.
///
/// A ValueHandle is meant to capture a single Value and should be used for
/// operations that have a single result. For convenience of use, we also
/// include AffineForOp in this category although it does not return a value.
/// In the case of AffineForOp, the captured Value is the loop induction
/// variable.
class ValueHandle : public CapturableHandle {
public:
/// A ValueHandle in a null state can never be captured;
static ValueHandle null() { return ValueHandle(); }
/// A ValueHandle that is constructed from a Type represents a typed "delayed"
/// Value. A delayed Value can only capture Values of the specified type.
/// Such a delayed value represents the declaration (in the PL sense) of a
/// placeholder for an mlir::Value that will be constructed and captured at
/// some later point in the program.
explicit ValueHandle(Type t) : t(t), v(nullptr) {}
/// A ValueHandle that is constructed from an mlir::Value is an "eager"
/// Value. An eager Value represents both the declaration and the definition
/// (in the PL sense) of a placeholder for an mlir::Value that has already
/// been constructed in the past and that is captured "now" in the program.
explicit ValueHandle(Value v) : t(v.getType()), v(v) {}
/// ValueHandle is a value type, use the default copy constructor.
ValueHandle(const ValueHandle &other) = default;
/// ValueHandle is a value type, the assignment operator typechecks before
/// assigning.
ValueHandle &operator=(const ValueHandle &other);
/// Provide a swap operator.
void swap(ValueHandle &other) {
if (this == &other)
return;
std::swap(t, other.t);
std::swap(v, other.v);
}
/// Implicit conversion useful for automatic conversion to Container<Value>.
operator Value() const { return getValue(); }
operator Type() const { return getType(); }
operator bool() const { return hasValue(); }
/// Generic mlir::Op create. This is the key to being extensible to the whole
/// of MLIR without duplicating the type system or the op definitions.
template <typename Op, typename... Args>
static ValueHandle create(Args... args);
/// Generic mlir::Op create. This is the key to being extensible to the whole
/// of MLIR without duplicating the type system or the op definitions.
/// When non-null, the optional pointer `folder` is used to call into the
/// `createAndFold` builder method. If `folder` is null, the regular `create`
/// method is called.
template <typename Op, typename... Args>
static ValueHandle create(OperationFolder *folder, Args... args);
/// Generic create for a named operation producing a single value.
static ValueHandle create(StringRef name, ArrayRef<ValueHandle> operands,
ArrayRef<Type> resultTypes,
ArrayRef<NamedAttribute> attributes = {});
bool hasValue() const { return v != nullptr; }
Value getValue() const {
assert(hasValue() && "Unexpected null value;");
return v;
}
bool hasType() const { return t != Type(); }
Type getType() const { return t; }
Operation *getOperation() const {
if (!v)
return nullptr;
return v.getDefiningOp();
}
// Return a vector of fresh ValueHandles that have not yet captured a Value.
static SmallVector<ValueHandle, 8> makeIndexHandles(unsigned count) {
auto indexType = IndexType::get(ScopedContext::getContext());
return SmallVector<ValueHandle, 8>(count, ValueHandle(indexType));
}
protected:
ValueHandle() : t(), v(nullptr) {}
Type t;
Value v;
};
/// An OperationHandle can be used in lieu of ValueHandle to capture the
/// An OperationHandle can be used in lieu of Value to capture the
/// operation in cases when one does not care about, or cannot extract, a
/// unique Value from the operation.
/// This can be used for capturing zero result operations as well as
/// multi-result operations that are not supported by ValueHandle.
/// multi-result operations that are not supported by Value.
/// We do not distinguish further between zero and multi-result operations at
/// this time.
struct OperationHandle : public CapturableHandle {
@@ -349,7 +258,7 @@ struct OperationHandle : public CapturableHandle {
static Op createOp(Args... args);
/// Generic create for a named operation.
static OperationHandle create(StringRef name, ArrayRef<ValueHandle> operands,
static OperationHandle create(StringRef name, ArrayRef<Value> operands,
ArrayRef<Type> resultTypes,
ArrayRef<NamedAttribute> attributes = {});
@@ -360,23 +269,6 @@ private:
Operation *op;
};
/// Simple wrapper to build a generic operation without successor blocks.
template <typename HandleType>
struct CustomOperation {
CustomOperation(StringRef name) : name(name) {
static_assert(std::is_same<HandleType, ValueHandle>() ||
std::is_same<HandleType, OperationHandle>(),
"Only CustomOperation<ValueHandle> or "
"CustomOperation<OperationHandle> can be constructed.");
}
HandleType operator()(ArrayRef<ValueHandle> operands = {},
ArrayRef<Type> resultTypes = {},
ArrayRef<NamedAttribute> attributes = {}) {
return HandleType::create(name, operands, resultTypes, attributes);
}
std::string name;
};
/// A BlockHandle represents a (potentially "delayed") Block abstraction.
/// This extra abstraction is necessary because an mlir::Block is not an
/// mlir::Value.
@@ -427,32 +319,45 @@ private:
/// C(buffer_value_or_tensor_type);
/// makeGenericLinalgOp({A({m, n}), B({k, n})}, {C({m, n})}, ... );
/// ```
struct StructuredIndexed : public ValueHandle {
StructuredIndexed(Type type) : ValueHandle(type) {}
StructuredIndexed(Value value) : ValueHandle(value) {}
StructuredIndexed(ValueHandle valueHandle) : ValueHandle(valueHandle) {}
struct StructuredIndexed {
StructuredIndexed(Value v) : value(v) {}
StructuredIndexed(Type t) : type(t) {}
StructuredIndexed operator()(ArrayRef<AffineExpr> indexings) {
return this->hasValue() ? StructuredIndexed(this->getValue(), indexings)
: StructuredIndexed(this->getType(), indexings);
return value ? StructuredIndexed(value, indexings)
: StructuredIndexed(type, indexings);
}
StructuredIndexed(Type t, ArrayRef<AffineExpr> indexings)
: ValueHandle(t), exprs(indexings.begin(), indexings.end()) {
assert(t.isa<RankedTensorType>() && "RankedTensor expected");
}
StructuredIndexed(Value v, ArrayRef<AffineExpr> indexings)
: ValueHandle(v), exprs(indexings.begin(), indexings.end()) {
: value(v), exprs(indexings.begin(), indexings.end()) {
assert((v.getType().isa<MemRefType>() ||
v.getType().isa<RankedTensorType>() ||
v.getType().isa<VectorType>()) &&
"MemRef, RankedTensor or Vector expected");
}
StructuredIndexed(ValueHandle vh, ArrayRef<AffineExpr> indexings)
: ValueHandle(vh), exprs(indexings.begin(), indexings.end()) {}
StructuredIndexed(Type t, ArrayRef<AffineExpr> indexings)
: type(t), exprs(indexings.begin(), indexings.end()) {
assert((t.isa<MemRefType>() || t.isa<RankedTensorType>() ||
t.isa<VectorType>()) &&
"MemRef, RankedTensor or Vector expected");
}
ArrayRef<AffineExpr> getExprs() { return exprs; }
bool hasValue() const { return value; }
Value getValue() const {
assert(value && "StructuredIndexed Value not set.");
return value;
}
Type getType() const {
assert((value || type) && "StructuredIndexed Value and Type not set.");
return value ? value.getType() : type;
}
ArrayRef<AffineExpr> getExprs() const { return exprs; }
operator Value() const { return getValue(); }
operator Type() const { return getType(); }
private:
// Only one of Value or type may be set.
Type type;
Value value;
SmallVector<AffineExpr, 4> exprs;
};
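A sketch of the intended use (`valA`, `valB`, `valC` are assumed `Value`s; `bindDims` is the `AffineExpr` binding helper):

```c++
AffineExpr m, n, k;
bindDims(ScopedContext::getContext(), m, n, k);  // assumed helper
StructuredIndexed A(valA), B(valB), C(valC);
auto a = A({m, k});   // "A indexed by (m, k)" in a structured-op spec
```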
@@ -472,179 +377,139 @@ Op OperationHandle::createOp(Args... args) {
.getOperation());
}
template <typename Op, typename... Args>
ValueHandle ValueHandle::create(Args... args) {
Operation *op = ScopedContext::getBuilder()
.create<Op>(ScopedContext::getLocation(), args...)
.getOperation();
if (op->getNumResults() == 1)
return ValueHandle(op->getResult(0));
llvm_unreachable("unsupported operation, use an OperationHandle instead");
}
/// Entry point to build multiple ValueHandles from a `Container` of Value or
/// Type.
template <typename Container>
inline SmallVector<ValueHandle, 8> makeValueHandles(Container values) {
SmallVector<ValueHandle, 8> res;
res.reserve(values.size());
for (auto v : values)
res.push_back(ValueHandle(v));
return res;
}
/// A TemplatedIndexedValue brings an index notation over the template Load and
/// Store parameters. Assigning to an IndexedValue emits an actual `Store`
/// operation, while converting an IndexedValue to a ValueHandle emits an actual
/// operation, while converting an IndexedValue to a Value emits an actual
/// `Load` operation.
template <typename Load, typename Store>
class TemplatedIndexedValue {
public:
explicit TemplatedIndexedValue(Type t) : base(t) {}
explicit TemplatedIndexedValue(Value v)
: TemplatedIndexedValue(ValueHandle(v)) {}
explicit TemplatedIndexedValue(ValueHandle v) : base(v) {}
explicit TemplatedIndexedValue(Value v) : value(v) {}
TemplatedIndexedValue(const TemplatedIndexedValue &rhs) = default;
TemplatedIndexedValue operator()() { return *this; }
/// Returns a new `TemplatedIndexedValue`.
TemplatedIndexedValue operator()(ValueHandle index) {
TemplatedIndexedValue res(base);
TemplatedIndexedValue operator()(Value index) {
TemplatedIndexedValue res(value);
res.indices.push_back(index);
return res;
}
template <typename... Args>
TemplatedIndexedValue operator()(ValueHandle index, Args... indices) {
return TemplatedIndexedValue(base, index).append(indices...);
TemplatedIndexedValue operator()(Value index, Args... indices) {
return TemplatedIndexedValue(value, index).append(indices...);
}
TemplatedIndexedValue operator()(ArrayRef<ValueHandle> indices) {
return TemplatedIndexedValue(base, indices);
TemplatedIndexedValue operator()(ArrayRef<Value> indices) {
return TemplatedIndexedValue(value, indices);
}
/// Emits a `store`.
OperationHandle operator=(const TemplatedIndexedValue &rhs) {
ValueHandle rrhs(rhs);
return Store(rrhs, getBase(), {indices.begin(), indices.end()});
}
OperationHandle operator=(ValueHandle rhs) {
return Store(rhs, getBase(), {indices.begin(), indices.end()});
}
/// Emits a `load` when converting to a ValueHandle.
operator ValueHandle() const {
return Load(getBase(), {indices.begin(), indices.end()});
return Store(rhs, value, indices);
}
OperationHandle operator=(Value rhs) { return Store(rhs, value, indices); }
/// Emits a `load` when converting to a Value.
Value operator*(void) const {
return Load(getBase(), {indices.begin(), indices.end()}).getValue();
}
operator Value() const { return Load(value, indices); }
ValueHandle getBase() const { return base; }
Value getBase() const { return value; }
/// Arithmetic operator overloadings.
ValueHandle operator+(ValueHandle e);
ValueHandle operator-(ValueHandle e);
ValueHandle operator*(ValueHandle e);
ValueHandle operator/(ValueHandle e);
ValueHandle operator%(ValueHandle e);
ValueHandle operator^(ValueHandle e);
ValueHandle operator+(TemplatedIndexedValue e) {
return *this + static_cast<ValueHandle>(e);
Value operator+(Value e);
Value operator-(Value e);
Value operator*(Value e);
Value operator/(Value e);
Value operator%(Value e);
Value operator^(Value e);
Value operator+(TemplatedIndexedValue e) {
return *this + static_cast<Value>(e);
}
ValueHandle operator-(TemplatedIndexedValue e) {
return *this - static_cast<ValueHandle>(e);
Value operator-(TemplatedIndexedValue e) {
return *this - static_cast<Value>(e);
}
ValueHandle operator*(TemplatedIndexedValue e) {
return *this * static_cast<ValueHandle>(e);
Value operator*(TemplatedIndexedValue e) {
return *this * static_cast<Value>(e);
}
ValueHandle operator/(TemplatedIndexedValue e) {
return *this / static_cast<ValueHandle>(e);
Value operator/(TemplatedIndexedValue e) {
return *this / static_cast<Value>(e);
}
ValueHandle operator%(TemplatedIndexedValue e) {
return *this % static_cast<ValueHandle>(e);
Value operator%(TemplatedIndexedValue e) {
return *this % static_cast<Value>(e);
}
ValueHandle operator^(TemplatedIndexedValue e) {
return *this ^ static_cast<ValueHandle>(e);
Value operator^(TemplatedIndexedValue e) {
return *this ^ static_cast<Value>(e);
}
/// Assignment-arithmetic operator overloadings.
OperationHandle operator+=(ValueHandle e);
OperationHandle operator-=(ValueHandle e);
OperationHandle operator*=(ValueHandle e);
OperationHandle operator/=(ValueHandle e);
OperationHandle operator%=(ValueHandle e);
OperationHandle operator^=(ValueHandle e);
OperationHandle operator+=(Value e);
OperationHandle operator-=(Value e);
OperationHandle operator*=(Value e);
OperationHandle operator/=(Value e);
OperationHandle operator%=(Value e);
OperationHandle operator^=(Value e);
OperationHandle operator+=(TemplatedIndexedValue e) {
return this->operator+=(static_cast<ValueHandle>(e));
return this->operator+=(static_cast<Value>(e));
}
OperationHandle operator-=(TemplatedIndexedValue e) {
return this->operator-=(static_cast<ValueHandle>(e));
return this->operator-=(static_cast<Value>(e));
}
OperationHandle operator*=(TemplatedIndexedValue e) {
return this->operator*=(static_cast<ValueHandle>(e));
return this->operator*=(static_cast<Value>(e));
}
OperationHandle operator/=(TemplatedIndexedValue e) {
return this->operator/=(static_cast<ValueHandle>(e));
return this->operator/=(static_cast<Value>(e));
}
OperationHandle operator%=(TemplatedIndexedValue e) {
return this->operator%=(static_cast<ValueHandle>(e));
return this->operator%=(static_cast<Value>(e));
}
OperationHandle operator^=(TemplatedIndexedValue e) {
return this->operator^=(static_cast<ValueHandle>(e));
return this->operator^=(static_cast<Value>(e));
}
/// Logical operator overloadings.
ValueHandle operator&&(ValueHandle e);
ValueHandle operator||(ValueHandle e);
ValueHandle operator&&(TemplatedIndexedValue e) {
return *this && static_cast<ValueHandle>(e);
Value operator&&(Value e);
Value operator||(Value e);
Value operator&&(TemplatedIndexedValue e) {
return *this && static_cast<Value>(e);
}
ValueHandle operator||(TemplatedIndexedValue e) {
return *this || static_cast<ValueHandle>(e);
Value operator||(TemplatedIndexedValue e) {
return *this || static_cast<Value>(e);
}
/// Comparison operator overloadings.
ValueHandle operator==(ValueHandle e);
ValueHandle operator!=(ValueHandle e);
ValueHandle operator<(ValueHandle e);
ValueHandle operator<=(ValueHandle e);
ValueHandle operator>(ValueHandle e);
ValueHandle operator>=(ValueHandle e);
ValueHandle operator==(TemplatedIndexedValue e) {
return *this == static_cast<ValueHandle>(e);
Value eq(Value e);
Value ne(Value e);
Value operator<(Value e);
Value operator<=(Value e);
Value operator>(Value e);
Value operator>=(Value e);
Value operator<(TemplatedIndexedValue e) {
return *this < static_cast<Value>(e);
}
ValueHandle operator!=(TemplatedIndexedValue e) {
return *this != static_cast<ValueHandle>(e);
Value operator<=(TemplatedIndexedValue e) {
return *this <= static_cast<Value>(e);
}
ValueHandle operator<(TemplatedIndexedValue e) {
return *this < static_cast<ValueHandle>(e);
Value operator>(TemplatedIndexedValue e) {
return *this > static_cast<Value>(e);
}
ValueHandle operator<=(TemplatedIndexedValue e) {
return *this <= static_cast<ValueHandle>(e);
}
ValueHandle operator>(TemplatedIndexedValue e) {
return *this > static_cast<ValueHandle>(e);
}
ValueHandle operator>=(TemplatedIndexedValue e) {
return *this >= static_cast<ValueHandle>(e);
Value operator>=(TemplatedIndexedValue e) {
return *this >= static_cast<Value>(e);
}
private:
TemplatedIndexedValue(ValueHandle base, ArrayRef<ValueHandle> indices)
: base(base), indices(indices.begin(), indices.end()) {}
TemplatedIndexedValue(Value value, ArrayRef<Value> indices)
: value(value), indices(indices.begin(), indices.end()) {}
TemplatedIndexedValue &append() { return *this; }
template <typename T, typename... Args>
TemplatedIndexedValue &append(T index, Args... indices) {
this->indices.push_back(static_cast<ValueHandle>(index));
this->indices.push_back(static_cast<Value>(index));
append(indices...);
return *this;
}
ValueHandle base;
SmallVector<ValueHandle, 8> indices;
Value value;
SmallVector<Value, 8> indices;
};
} // namespace edsc
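Putting the pieces together, the index-notation sugar reads like this sketch (`bufA`..`bufC` are assumed memref `Value`s; `i` and `j` are index `Value`s):

```c++
StdIndexedValue A(bufA), B(bufB), C(bufC);
// The reads emit std.load; the assignment emits std.store.
C(i, j) = A(i, j) + B(i, j);
```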


@@ -25,99 +25,27 @@ class Type;
namespace edsc {
/// Entry point to build multiple ValueHandle* from a mutable list `ivs`.
inline SmallVector<ValueHandle *, 8>
makeHandlePointers(MutableArrayRef<ValueHandle> ivs) {
SmallVector<ValueHandle *, 8> pivs;
pivs.reserve(ivs.size());
for (auto &iv : ivs)
pivs.push_back(&iv);
return pivs;
}
/// Provides a set of first class intrinsics.
/// In the future, most of the intrinsics related to Operation that don't contain
/// other operations should be Tablegen'd.
namespace intrinsics {
namespace detail {
/// Helper structure to be used with ValueBuilder / OperationBuilder.
/// Its sole purpose is to remove the boilerplate specializations otherwise
/// needed to implicitly convert ArrayRef<ValueHandle> -> ArrayRef<Value>.
class ValueHandleArray {
public:
ValueHandleArray(ArrayRef<ValueHandle> vals) {
values.append(vals.begin(), vals.end());
}
operator ArrayRef<Value>() { return values; }
private:
ValueHandleArray() = default;
SmallVector<Value, 8> values;
};
template <typename T>
inline T unpack(T value) {
return value;
}
inline detail::ValueHandleArray unpack(ArrayRef<ValueHandle> values) {
return detail::ValueHandleArray(values);
}
} // namespace detail
/// Helper variadic abstraction to allow extending to any MLIR op without
/// boilerplate or Tablegen.
/// Arguably a builder is not a ValueHandle but in practice it is only used as
/// an alias to a notional ValueHandle<Op>.
/// Implementing it as a subclass allows it to compose all the way to Value.
/// Without subclassing, implicit conversion to Value would fail when composing
/// in patterns such as: `select(a, b, select(c, d, e))`.
template <typename Op>
struct ValueBuilder : public ValueHandle {
// Builder-based
template <typename... Args>
ValueBuilder(Args... args)
: ValueHandle(ValueHandle::create<Op>(detail::unpack(args)...)) {}
ValueBuilder(ArrayRef<ValueHandle> vs)
: ValueBuilder(ValueBuilder::create<Op>(detail::unpack(vs))) {}
template <typename... Args>
ValueBuilder(ArrayRef<ValueHandle> vs, Args... args)
: ValueHandle(ValueHandle::create<Op>(detail::unpack(vs),
detail::unpack(args)...)) {}
template <typename T, typename... Args>
ValueBuilder(T t, ArrayRef<ValueHandle> vs, Args... args)
: ValueHandle(ValueHandle::create<Op>(
detail::unpack(t), detail::unpack(vs), detail::unpack(args)...)) {}
template <typename T1, typename T2, typename... Args>
ValueBuilder(T1 t1, T2 t2, ArrayRef<ValueHandle> vs, Args... args)
: ValueHandle(ValueHandle::create<Op>(
detail::unpack(t1), detail::unpack(t2), detail::unpack(vs),
detail::unpack(args)...)) {}
ValueBuilder() : ValueHandle(ValueHandle::create<Op>()) {}
};
template <typename Op>
struct OperationBuilder : public OperationHandle {
template <typename... Args>
OperationBuilder(Args... args)
: OperationHandle(OperationHandle::create<Op>(detail::unpack(args)...)) {}
OperationBuilder(ArrayRef<ValueHandle> vs)
: OperationHandle(OperationHandle::create<Op>(detail::unpack(vs))) {}
: OperationHandle(OperationHandle::create<Op>(args...)) {}
OperationBuilder(ArrayRef<Value> vs)
: OperationHandle(OperationHandle::create<Op>(vs)) {}
template <typename... Args>
OperationBuilder(ArrayRef<ValueHandle> vs, Args... args)
: OperationHandle(OperationHandle::create<Op>(detail::unpack(vs),
detail::unpack(args)...)) {}
OperationBuilder(ArrayRef<Value> vs, Args... args)
: OperationHandle(OperationHandle::create<Op>(vs, args...)) {}
template <typename T, typename... Args>
OperationBuilder(T t, ArrayRef<ValueHandle> vs, Args... args)
: OperationHandle(OperationHandle::create<Op>(
detail::unpack(t), detail::unpack(vs), detail::unpack(args)...)) {}
OperationBuilder(T t, ArrayRef<Value> vs, Args... args)
: OperationHandle(OperationHandle::create<Op>(t, vs, args...)) {}
template <typename T1, typename T2, typename... Args>
OperationBuilder(T1 t1, T2 t2, ArrayRef<ValueHandle> vs, Args... args)
: OperationHandle(OperationHandle::create<Op>(
detail::unpack(t1), detail::unpack(t2), detail::unpack(vs),
detail::unpack(args)...)) {}
OperationBuilder(T1 t1, T2 t2, ArrayRef<Value> vs, Args... args)
: OperationHandle(OperationHandle::create<Op>(t1, t2, vs, args...)) {}
OperationBuilder() : OperationHandle(OperationHandle::create<Op>()) {}
};
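For illustration, named intrinsics are simple aliases over these builders. A minimal sketch, assuming an active ScopedContext and two previously captured f32 Values `a` and `b` (MulFOp and ReturnOp are standard-dialect ops):

```c++
using std_mulf = ValueBuilder<MulFOp>;      // single-result op: converts to Value
using std_ret = OperationBuilder<ReturnOp>; // zero-result op: stays an OperationHandle

Value prod = std_mulf(a, b); // creates std.mulf at the current insertion point
std_ret(ValueRange{prod});   // creates std.return
```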

View File

@@ -16,6 +16,7 @@
#include "mlir/Dialect/Affine/EDSC/Intrinsics.h"
#include "mlir/Dialect/LoopOps/EDSC/Builders.h"
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
#include "mlir/Dialect/Vector/EDSC/Intrinsics.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
@@ -38,9 +39,7 @@ using vector::TransferWriteOp;
/// `pivs` and `vectorBoundsCapture` are swapped so that the invocation of
/// LoopNestBuilder captures it in the innermost loop.
template <typename TransferOpTy>
static void coalesceCopy(TransferOpTy transfer,
SmallVectorImpl<ValueHandle *> *pivs,
VectorBoundsCapture *vectorBoundsCapture) {
static int computeCoalescedIndex(TransferOpTy transfer) {
// Rank of the remote memory access; coalescing behavior occurs on the
// innermost memory dimension.
auto remoteRank = transfer.getMemRefType().getRank();
@@ -62,24 +61,19 @@ static void coalesceCopy(TransferOpTy transfer,
coalescedIdx = en.index();
}
}
if (coalescedIdx >= 0) {
std::swap(pivs->back(), (*pivs)[coalescedIdx]);
vectorBoundsCapture->swapRanges(pivs->size() - 1, coalescedIdx);
}
return coalescedIdx;
}
/// Emits remote memory accesses that are clipped to the boundaries of the
/// MemRef.
template <typename TransferOpTy>
static SmallVector<ValueHandle, 8> clip(TransferOpTy transfer,
MemRefBoundsCapture &bounds,
ArrayRef<ValueHandle> ivs) {
static SmallVector<Value, 8>
clip(TransferOpTy transfer, MemRefBoundsCapture &bounds, ArrayRef<Value> ivs) {
using namespace mlir::edsc;
ValueHandle zero(std_constant_index(0)), one(std_constant_index(1));
SmallVector<ValueHandle, 8> memRefAccess(transfer.indices());
auto clippedScalarAccessExprs =
ValueHandle::makeIndexHandles(memRefAccess.size());
Value zero(std_constant_index(0)), one(std_constant_index(1));
SmallVector<Value, 8> memRefAccess(transfer.indices());
SmallVector<Value, 8> clippedScalarAccessExprs(memRefAccess.size());
// Indices accessing remote memory are clipped, and their expressions are
// returned in clippedScalarAccessExprs.
for (unsigned memRefDim = 0; memRefDim < clippedScalarAccessExprs.size();
@@ -126,8 +120,6 @@ static SmallVector<ValueHandle, 8> clip(TransferOpTy transfer,
namespace {
using vector_type_cast = edsc::intrinsics::ValueBuilder<vector::TypeCastOp>;
/// Implements lowering of TransferReadOp and TransferWriteOp to a
/// proper abstraction for the hardware.
///
@@ -257,31 +249,36 @@ LogicalResult VectorTransferRewriter<TransferReadOp>::matchAndRewrite(
StdIndexedValue remote(transfer.memref());
MemRefBoundsCapture memRefBoundsCapture(transfer.memref());
VectorBoundsCapture vectorBoundsCapture(transfer.vector());
auto ivs = ValueHandle::makeIndexHandles(vectorBoundsCapture.rank());
SmallVector<ValueHandle *, 8> pivs =
makeHandlePointers(MutableArrayRef<ValueHandle>(ivs));
coalesceCopy(transfer, &pivs, &vectorBoundsCapture);
int coalescedIdx = computeCoalescedIndex(transfer);
// Swap the vectorBoundsCapture, which will reorder loop bounds.
if (coalescedIdx >= 0)
vectorBoundsCapture.swapRanges(vectorBoundsCapture.rank() - 1,
coalescedIdx);
auto lbs = vectorBoundsCapture.getLbs();
auto ubs = vectorBoundsCapture.getUbs();
SmallVector<ValueHandle, 8> steps;
SmallVector<Value, 8> steps;
steps.reserve(vectorBoundsCapture.getSteps().size());
for (auto step : vectorBoundsCapture.getSteps())
steps.push_back(std_constant_index(step));
// 2. Emit alloc-copy-load-dealloc.
ValueHandle tmp = std_alloc(tmpMemRefType(transfer));
Value tmp = std_alloc(tmpMemRefType(transfer));
StdIndexedValue local(tmp);
ValueHandle vec = vector_type_cast(tmp);
LoopNestBuilder(pivs, lbs, ubs, steps)([&] {
Value vec = vector_type_cast(tmp);
SmallVector<Value, 8> ivs(lbs.size());
LoopNestBuilder(ivs, lbs, ubs, steps)([&] {
// Swap the ivs, which will reorder memory accesses.
if (coalescedIdx >= 0)
std::swap(ivs.back(), ivs[coalescedIdx]);
// Computes clippedScalarAccessExprs in the loop nest scope (ivs exist).
local(ivs) = remote(clip(transfer, memRefBoundsCapture, ivs));
});
ValueHandle vectorValue = std_load(vec);
Value vectorValue = std_load(vec);
(std_dealloc(tmp)); // extra parens avoid the most vexing parse
// 3. Propagate.
rewriter.replaceOp(op, vectorValue.getValue());
rewriter.replaceOp(op, vectorValue);
return success();
}
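To summarize the restructuring: the retired coalesceCopy swapped both the pivs and the bounds internally, whereas computeCoalescedIndex only reports the coalesced dimension and the caller performs the two swaps explicitly. A condensed sketch of the resulting control flow, reusing the captures defined above:

```c++
int coalescedIdx = computeCoalescedIndex(transfer);
if (coalescedIdx >= 0) // Reorder the loop bounds outside the nest...
  vectorBoundsCapture.swapRanges(vectorBoundsCapture.rank() - 1, coalescedIdx);
SmallVector<Value, 8> ivs(lbs.size()); // null Values, filled in by the builder
LoopNestBuilder(ivs, lbs, ubs, steps)([&] {
  if (coalescedIdx >= 0) // ...and the induction variables inside it.
    std::swap(ivs.back(), ivs[coalescedIdx]);
  local(ivs) = remote(clip(transfer, memRefBoundsCapture, ivs));
});
```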
@@ -314,26 +311,31 @@ LogicalResult VectorTransferRewriter<TransferWriteOp>::matchAndRewrite(
ScopedContext scope(rewriter, transfer.getLoc());
StdIndexedValue remote(transfer.memref());
MemRefBoundsCapture memRefBoundsCapture(transfer.memref());
ValueHandle vectorValue(transfer.vector());
Value vectorValue(transfer.vector());
VectorBoundsCapture vectorBoundsCapture(transfer.vector());
auto ivs = ValueHandle::makeIndexHandles(vectorBoundsCapture.rank());
SmallVector<ValueHandle *, 8> pivs =
makeHandlePointers(MutableArrayRef<ValueHandle>(ivs));
coalesceCopy(transfer, &pivs, &vectorBoundsCapture);
int coalescedIdx = computeCoalescedIndex(transfer);
// Swap the vectorBoundsCapture, which will reorder loop bounds.
if (coalescedIdx >= 0)
vectorBoundsCapture.swapRanges(vectorBoundsCapture.rank() - 1,
coalescedIdx);
auto lbs = vectorBoundsCapture.getLbs();
auto ubs = vectorBoundsCapture.getUbs();
SmallVector<ValueHandle, 8> steps;
SmallVector<Value, 8> steps;
steps.reserve(vectorBoundsCapture.getSteps().size());
for (auto step : vectorBoundsCapture.getSteps())
steps.push_back(std_constant_index(step));
// 2. Emit alloc-store-copy-dealloc.
ValueHandle tmp = std_alloc(tmpMemRefType(transfer));
Value tmp = std_alloc(tmpMemRefType(transfer));
StdIndexedValue local(tmp);
ValueHandle vec = vector_type_cast(tmp);
Value vec = vector_type_cast(tmp);
std_store(vectorValue, vec);
LoopNestBuilder(pivs, lbs, ubs, steps)([&] {
SmallVector<Value, 8> ivs(lbs.size());
LoopNestBuilder(ivs, lbs, ubs, steps)([&] {
// Swap the ivs, which will reorder memory accesses.
if (coalescedIdx >= 0)
std::swap(ivs.back(), ivs[coalescedIdx]);
// Computes clippedScalarAccessExprs in the loop nest scope (ivs exist).
remote(clip(transfer, memRefBoundsCapture, ivs)) = local(ivs);
});

View File

@@ -14,65 +14,61 @@
using namespace mlir;
using namespace mlir::edsc;
static Optional<ValueHandle> emitStaticFor(ArrayRef<ValueHandle> lbs,
ArrayRef<ValueHandle> ubs,
int64_t step) {
static Optional<Value> emitStaticFor(ArrayRef<Value> lbs, ArrayRef<Value> ubs,
int64_t step) {
if (lbs.size() != 1 || ubs.size() != 1)
return Optional<ValueHandle>();
return Optional<Value>();
auto *lbDef = lbs.front().getValue().getDefiningOp();
auto *ubDef = ubs.front().getValue().getDefiningOp();
auto *lbDef = lbs.front().getDefiningOp();
auto *ubDef = ubs.front().getDefiningOp();
if (!lbDef || !ubDef)
return Optional<ValueHandle>();
return Optional<Value>();
auto lbConst = dyn_cast<ConstantIndexOp>(lbDef);
auto ubConst = dyn_cast<ConstantIndexOp>(ubDef);
if (!lbConst || !ubConst)
return Optional<ValueHandle>();
return ValueHandle(ScopedContext::getBuilder()
.create<AffineForOp>(ScopedContext::getLocation(),
lbConst.getValue(),
ubConst.getValue(), step)
.getInductionVar());
return Optional<Value>();
return ScopedContext::getBuilder()
.create<AffineForOp>(ScopedContext::getLocation(), lbConst.getValue(),
ubConst.getValue(), step)
.getInductionVar();
}
LoopBuilder mlir::edsc::makeAffineLoopBuilder(ValueHandle *iv,
ArrayRef<ValueHandle> lbHandles,
ArrayRef<ValueHandle> ubHandles,
LoopBuilder mlir::edsc::makeAffineLoopBuilder(Value *iv, ArrayRef<Value> lbs,
ArrayRef<Value> ubs,
int64_t step) {
mlir::edsc::LoopBuilder result;
if (auto staticFor = emitStaticFor(lbHandles, ubHandles, step)) {
*iv = staticFor.getValue();
if (auto staticForIv = emitStaticFor(lbs, ubs, step)) {
*iv = staticForIv.getValue();
} else {
SmallVector<Value, 4> lbs(lbHandles.begin(), lbHandles.end());
SmallVector<Value, 4> ubs(ubHandles.begin(), ubHandles.end());
auto b = ScopedContext::getBuilder();
*iv = ValueHandle(
b.create<AffineForOp>(ScopedContext::getLocation(), lbs,
b.getMultiDimIdentityMap(lbs.size()), ubs,
b.getMultiDimIdentityMap(ubs.size()), step)
.getInductionVar());
*iv =
Value(b.create<AffineForOp>(ScopedContext::getLocation(), lbs,
b.getMultiDimIdentityMap(lbs.size()), ubs,
b.getMultiDimIdentityMap(ubs.size()), step)
.getInductionVar());
}
auto *body = getForInductionVarOwner(iv->getValue()).getBody();
auto *body = getForInductionVarOwner(*iv).getBody();
result.enter(body, /*prev=*/1);
return result;
}
mlir::edsc::AffineLoopNestBuilder::AffineLoopNestBuilder(
ValueHandle *iv, ArrayRef<ValueHandle> lbs, ArrayRef<ValueHandle> ubs,
int64_t step) {
mlir::edsc::AffineLoopNestBuilder::AffineLoopNestBuilder(Value *iv,
ArrayRef<Value> lbs,
ArrayRef<Value> ubs,
int64_t step) {
loops.emplace_back(makeAffineLoopBuilder(iv, lbs, ubs, step));
}
mlir::edsc::AffineLoopNestBuilder::AffineLoopNestBuilder(
ArrayRef<ValueHandle *> ivs, ArrayRef<ValueHandle> lbs,
ArrayRef<ValueHandle> ubs, ArrayRef<int64_t> steps) {
MutableArrayRef<Value> ivs, ArrayRef<Value> lbs, ArrayRef<Value> ubs,
ArrayRef<int64_t> steps) {
assert(ivs.size() == lbs.size() && "Mismatch in number of arguments");
assert(ivs.size() == ubs.size() && "Mismatch in number of arguments");
assert(ivs.size() == steps.size() && "Mismatch in number of arguments");
for (auto it : llvm::zip(ivs, lbs, ubs, steps))
loops.emplace_back(makeAffineLoopBuilder(std::get<0>(it), std::get<1>(it),
loops.emplace_back(makeAffineLoopBuilder(&std::get<0>(it), std::get<1>(it),
std::get<2>(it), std::get<3>(it)));
}
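For reference, the Value-based nest builder is invoked as in this minimal sketch, assuming an active ScopedContext and two index-typed Values `lb` and `ub`:

```c++
// The ivs start out as null Values; the builder assigns the actual loop
// induction variables into them before invoking the body callback.
SmallVector<Value, 2> ivs(2);
AffineLoopNestBuilder(ivs, /*lbs=*/{lb, lb}, /*ubs=*/{ub, ub},
                      /*steps=*/{1, 1})([&] {
  // ivs[0] and ivs[1] are valid induction variables here.
});
```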
@@ -89,11 +85,6 @@ void mlir::edsc::AffineLoopNestBuilder::operator()(
(*lit)();
}
template <typename Op>
static ValueHandle createBinaryHandle(ValueHandle lhs, ValueHandle rhs) {
return ValueHandle::create<Op>(lhs.getValue(), rhs.getValue());
}
static std::pair<AffineExpr, Value>
categorizeValueByAffineType(MLIRContext *context, Value val, unsigned &numDims,
unsigned &numSymbols) {
@@ -111,115 +102,109 @@ categorizeValueByAffineType(MLIRContext *context, Value val, unsigned &numDims,
return std::make_pair(d, resultVal);
}
static ValueHandle createBinaryIndexHandle(
ValueHandle lhs, ValueHandle rhs,
static Value createBinaryIndexHandle(
Value lhs, Value rhs,
function_ref<AffineExpr(AffineExpr, AffineExpr)> affCombiner) {
MLIRContext *context = ScopedContext::getContext();
unsigned numDims = 0, numSymbols = 0;
AffineExpr d0, d1;
Value v0, v1;
std::tie(d0, v0) =
categorizeValueByAffineType(context, lhs.getValue(), numDims, numSymbols);
categorizeValueByAffineType(context, lhs, numDims, numSymbols);
std::tie(d1, v1) =
categorizeValueByAffineType(context, rhs.getValue(), numDims, numSymbols);
categorizeValueByAffineType(context, rhs, numDims, numSymbols);
SmallVector<Value, 2> operands;
if (v0) {
if (v0)
operands.push_back(v0);
}
if (v1) {
if (v1)
operands.push_back(v1);
}
auto map = AffineMap::get(numDims, numSymbols, affCombiner(d0, d1));
// TODO: createOrFold when available.
Operation *op =
makeComposedAffineApply(ScopedContext::getBuilder(),
ScopedContext::getLocation(), map, operands)
.getOperation();
assert(op->getNumResults() == 1 && "Expected single result AffineApply");
return ValueHandle(op->getResult(0));
return op->getResult(0);
}
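Concretely, the operator overloads below route index arithmetic through this helper so that chained expressions compose into affine applies. A minimal sketch, assuming an active ScopedContext and two index-typed Values `i` and `j`:

```c++
using edsc::op::operator+;
using edsc::op::floorDiv;

Value k = i + j; // one affine.apply of (d0, d1) -> (d0 + d1); dims vs. symbols
                 // are chosen by categorizeValueByAffineType
Value q = floorDiv(k, std_constant_index(4)); // same path, floordiv expression
```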
template <typename IOp, typename FOp>
static ValueHandle createBinaryHandle(
ValueHandle lhs, ValueHandle rhs,
static Value createBinaryHandle(
Value lhs, Value rhs,
function_ref<AffineExpr(AffineExpr, AffineExpr)> affCombiner) {
auto thisType = lhs.getValue().getType();
auto thatType = rhs.getValue().getType();
auto thisType = lhs.getType();
auto thatType = rhs.getType();
assert(thisType == thatType && "cannot mix types in operators");
(void)thisType;
(void)thatType;
if (thisType.isIndex()) {
return createBinaryIndexHandle(lhs, rhs, affCombiner);
} else if (thisType.isSignlessInteger()) {
return createBinaryHandle<IOp>(lhs, rhs);
return ValueBuilder<IOp>(lhs, rhs);
} else if (thisType.isa<FloatType>()) {
return createBinaryHandle<FOp>(lhs, rhs);
return ValueBuilder<FOp>(lhs, rhs);
} else if (thisType.isa<VectorType>() || thisType.isa<TensorType>()) {
auto aggregateType = thisType.cast<ShapedType>();
if (aggregateType.getElementType().isSignlessInteger())
return createBinaryHandle<IOp>(lhs, rhs);
return ValueBuilder<IOp>(lhs, rhs);
else if (aggregateType.getElementType().isa<FloatType>())
return createBinaryHandle<FOp>(lhs, rhs);
return ValueBuilder<FOp>(lhs, rhs);
}
llvm_unreachable("failed to create a ValueHandle");
llvm_unreachable("failed to create a Value");
}
ValueHandle mlir::edsc::op::operator+(ValueHandle lhs, ValueHandle rhs) {
Value mlir::edsc::op::operator+(Value lhs, Value rhs) {
return createBinaryHandle<AddIOp, AddFOp>(
lhs, rhs, [](AffineExpr d0, AffineExpr d1) { return d0 + d1; });
}
ValueHandle mlir::edsc::op::operator-(ValueHandle lhs, ValueHandle rhs) {
Value mlir::edsc::op::operator-(Value lhs, Value rhs) {
return createBinaryHandle<SubIOp, SubFOp>(
lhs, rhs, [](AffineExpr d0, AffineExpr d1) { return d0 - d1; });
}
ValueHandle mlir::edsc::op::operator*(ValueHandle lhs, ValueHandle rhs) {
Value mlir::edsc::op::operator*(Value lhs, Value rhs) {
return createBinaryHandle<MulIOp, MulFOp>(
lhs, rhs, [](AffineExpr d0, AffineExpr d1) { return d0 * d1; });
}
ValueHandle mlir::edsc::op::operator/(ValueHandle lhs, ValueHandle rhs) {
Value mlir::edsc::op::operator/(Value lhs, Value rhs) {
return createBinaryHandle<SignedDivIOp, DivFOp>(
lhs, rhs, [](AffineExpr d0, AffineExpr d1) -> AffineExpr {
llvm_unreachable("only exprs of non-index type support operator/");
});
}
ValueHandle mlir::edsc::op::operator%(ValueHandle lhs, ValueHandle rhs) {
Value mlir::edsc::op::operator%(Value lhs, Value rhs) {
return createBinaryHandle<SignedRemIOp, RemFOp>(
lhs, rhs, [](AffineExpr d0, AffineExpr d1) { return d0 % d1; });
}
ValueHandle mlir::edsc::op::floorDiv(ValueHandle lhs, ValueHandle rhs) {
Value mlir::edsc::op::floorDiv(Value lhs, Value rhs) {
return createBinaryIndexHandle(
lhs, rhs, [](AffineExpr d0, AffineExpr d1) { return d0.floorDiv(d1); });
}
ValueHandle mlir::edsc::op::ceilDiv(ValueHandle lhs, ValueHandle rhs) {
Value mlir::edsc::op::ceilDiv(Value lhs, Value rhs) {
return createBinaryIndexHandle(
lhs, rhs, [](AffineExpr d0, AffineExpr d1) { return d0.ceilDiv(d1); });
}
ValueHandle mlir::edsc::op::operator!(ValueHandle value) {
assert(value.getType().isInteger(1) && "expected boolean expression");
return ValueHandle::create<ConstantIntOp>(1, 1) - value;
}
ValueHandle mlir::edsc::op::operator&&(ValueHandle lhs, ValueHandle rhs) {
Value mlir::edsc::op::operator&&(Value lhs, Value rhs) {
assert(lhs.getType().isInteger(1) && "expected boolean expression on LHS");
assert(rhs.getType().isInteger(1) && "expected boolean expression on RHS");
return ValueHandle::create<AndOp>(lhs, rhs);
return ValueBuilder<AndOp>(lhs, rhs);
}
ValueHandle mlir::edsc::op::operator||(ValueHandle lhs, ValueHandle rhs) {
Value mlir::edsc::op::operator||(Value lhs, Value rhs) {
assert(lhs.getType().isInteger(1) && "expected boolean expression on LHS");
assert(rhs.getType().isInteger(1) && "expected boolean expression on RHS");
return ValueHandle::create<OrOp>(lhs, rhs);
return ValueBuilder<OrOp>(lhs, rhs);
}
static ValueHandle createIComparisonExpr(CmpIPredicate predicate,
ValueHandle lhs, ValueHandle rhs) {
static Value createIComparisonExpr(CmpIPredicate predicate, Value lhs,
Value rhs) {
auto lhsType = lhs.getType();
auto rhsType = rhs.getType();
(void)lhsType;
@@ -228,13 +213,12 @@ static ValueHandle createIComparisonExpr(CmpIPredicate predicate,
assert((lhsType.isa<IndexType>() || lhsType.isSignlessInteger()) &&
"only integer comparisons are supported");
auto op = ScopedContext::getBuilder().create<CmpIOp>(
ScopedContext::getLocation(), predicate, lhs.getValue(), rhs.getValue());
return ValueHandle(op.getResult());
return ScopedContext::getBuilder().create<CmpIOp>(
ScopedContext::getLocation(), predicate, lhs, rhs);
}
static ValueHandle createFComparisonExpr(CmpFPredicate predicate,
ValueHandle lhs, ValueHandle rhs) {
static Value createFComparisonExpr(CmpFPredicate predicate, Value lhs,
Value rhs) {
auto lhsType = lhs.getType();
auto rhsType = rhs.getType();
(void)lhsType;
@@ -242,25 +226,24 @@ static ValueHandle createFComparisonExpr(CmpFPredicate predicate,
assert(lhsType == rhsType && "cannot mix types in operators");
assert(lhsType.isa<FloatType>() && "only float comparisons are supported");
auto op = ScopedContext::getBuilder().create<CmpFOp>(
ScopedContext::getLocation(), predicate, lhs.getValue(), rhs.getValue());
return ValueHandle(op.getResult());
return ScopedContext::getBuilder().create<CmpFOp>(
ScopedContext::getLocation(), predicate, lhs, rhs);
}
// All floating point comparisons are ordered through the EDSL.
ValueHandle mlir::edsc::op::operator==(ValueHandle lhs, ValueHandle rhs) {
Value mlir::edsc::op::eq(Value lhs, Value rhs) {
auto type = lhs.getType();
return type.isa<FloatType>()
? createFComparisonExpr(CmpFPredicate::OEQ, lhs, rhs)
: createIComparisonExpr(CmpIPredicate::eq, lhs, rhs);
}
ValueHandle mlir::edsc::op::operator!=(ValueHandle lhs, ValueHandle rhs) {
Value mlir::edsc::op::ne(Value lhs, Value rhs) {
auto type = lhs.getType();
return type.isa<FloatType>()
? createFComparisonExpr(CmpFPredicate::ONE, lhs, rhs)
: createIComparisonExpr(CmpIPredicate::ne, lhs, rhs);
}
ValueHandle mlir::edsc::op::operator<(ValueHandle lhs, ValueHandle rhs) {
Value mlir::edsc::op::operator<(Value lhs, Value rhs) {
auto type = lhs.getType();
return type.isa<FloatType>()
? createFComparisonExpr(CmpFPredicate::OLT, lhs, rhs)
@@ -268,19 +251,19 @@ ValueHandle mlir::edsc::op::operator<(ValueHandle lhs, ValueHandle rhs) {
// TODO(ntv,zinenko): signed by default, how about unsigned?
createIComparisonExpr(CmpIPredicate::slt, lhs, rhs);
}
ValueHandle mlir::edsc::op::operator<=(ValueHandle lhs, ValueHandle rhs) {
Value mlir::edsc::op::operator<=(Value lhs, Value rhs) {
auto type = lhs.getType();
return type.isa<FloatType>()
? createFComparisonExpr(CmpFPredicate::OLE, lhs, rhs)
: createIComparisonExpr(CmpIPredicate::sle, lhs, rhs);
}
ValueHandle mlir::edsc::op::operator>(ValueHandle lhs, ValueHandle rhs) {
Value mlir::edsc::op::operator>(Value lhs, Value rhs) {
auto type = lhs.getType();
return type.isa<FloatType>()
? createFComparisonExpr(CmpFPredicate::OGT, lhs, rhs)
: createIComparisonExpr(CmpIPredicate::sgt, lhs, rhs);
}
ValueHandle mlir::edsc::op::operator>=(ValueHandle lhs, ValueHandle rhs) {
Value mlir::edsc::op::operator>=(Value lhs, Value rhs) {
auto type = lhs.getType();
return type.isa<FloatType>()
? createFComparisonExpr(CmpFPredicate::OGE, lhs, rhs)

View File

@@ -44,14 +44,14 @@ static void insertCopyLoops(OpBuilder &builder, Location loc,
MemRefBoundsCapture &bounds, Value from, Value to) {
// Create EDSC handles for bounds.
unsigned rank = bounds.rank();
SmallVector<ValueHandle, 4> lbs, ubs, steps;
SmallVector<Value, 4> lbs, ubs, steps;
// Make sure we have enough loops to use all thread dimensions; these trivial
// loops should be outermost and therefore inserted first.
if (rank < GPUDialect::getNumWorkgroupDimensions()) {
unsigned extraLoops = GPUDialect::getNumWorkgroupDimensions() - rank;
ValueHandle zero = std_constant_index(0);
ValueHandle one = std_constant_index(1);
Value zero = std_constant_index(0);
Value one = std_constant_index(1);
lbs.resize(extraLoops, zero);
ubs.resize(extraLoops, one);
steps.resize(extraLoops, one);
@@ -78,9 +78,8 @@ static void insertCopyLoops(OpBuilder &builder, Location loc,
}
// Produce the loop nest with copies.
SmallVector<ValueHandle, 8> ivs(lbs.size(), ValueHandle(indexType));
auto ivPtrs = makeHandlePointers(MutableArrayRef<ValueHandle>(ivs));
LoopNestBuilder(ivPtrs, lbs, ubs, steps)([&]() {
SmallVector<Value, 8> ivs(lbs.size());
LoopNestBuilder(ivs, lbs, ubs, steps)([&]() {
auto activeIvs = llvm::makeArrayRef(ivs).take_back(rank);
StdIndexedValue fromHandle(from), toHandle(to);
toHandle(activeIvs) = fromHandle(activeIvs);
@@ -90,8 +89,8 @@ static void insertCopyLoops(OpBuilder &builder, Location loc,
for (auto en :
llvm::enumerate(llvm::reverse(llvm::makeArrayRef(ivs).take_back(
GPUDialect::getNumWorkgroupDimensions())))) {
auto loop = cast<loop::ForOp>(
en.value().getValue().getParentRegion()->getParentOp());
Value v = en.value();
auto loop = cast<loop::ForOp>(v.getParentRegion()->getParentOp());
mapLoopToProcessorIds(loop, {threadIds[en.index()]},
{blockDims[en.index()]});
}
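The subscripted assignment used in the copy body above relies on the IndexedValue sugar; a minimal sketch of what it expands to, assuming an active ScopedContext and valid index Values in `ivs`:

```c++
StdIndexedValue src(from), dst(to);
// Emits, at the current insertion point:
//   %v = load %from[ivs...] ; store %v, %to[ivs...]
dst(ivs) = src(ivs);
```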

View File

@@ -21,69 +21,61 @@ using namespace mlir::edsc::intrinsics;
using namespace mlir::linalg;
using namespace mlir::loop;
mlir::edsc::LoopRangeBuilder::LoopRangeBuilder(ValueHandle *iv,
ValueHandle range) {
mlir::edsc::LoopRangeBuilder::LoopRangeBuilder(Value *iv, Value range) {
assert(range.getType() && "expected !linalg.range type");
assert(range.getValue().getDefiningOp() &&
"need operations to extract range parts");
auto rangeOp = cast<RangeOp>(range.getValue().getDefiningOp());
assert(range.getDefiningOp() && "need operations to extract range parts");
auto rangeOp = cast<RangeOp>(range.getDefiningOp());
auto lb = rangeOp.min();
auto ub = rangeOp.max();
auto step = rangeOp.step();
auto forOp = OperationHandle::createOp<ForOp>(lb, ub, step);
*iv = ValueHandle(forOp.getInductionVar());
*iv = forOp.getInductionVar();
auto *body = forOp.getBody();
enter(body, /*prev=*/1);
}
mlir::edsc::LoopRangeBuilder::LoopRangeBuilder(ValueHandle *iv,
mlir::edsc::LoopRangeBuilder::LoopRangeBuilder(Value *iv,
SubViewOp::Range range) {
auto forOp =
OperationHandle::createOp<ForOp>(range.offset, range.size, range.stride);
*iv = ValueHandle(forOp.getInductionVar());
*iv = forOp.getInductionVar();
auto *body = forOp.getBody();
enter(body, /*prev=*/1);
}
ValueHandle
mlir::edsc::LoopRangeBuilder::operator()(std::function<void(void)> fun) {
Value mlir::edsc::LoopRangeBuilder::operator()(std::function<void(void)> fun) {
if (fun)
fun();
exit();
return ValueHandle::null();
return Value();
}
mlir::edsc::LoopNestRangeBuilder::LoopNestRangeBuilder(
ArrayRef<ValueHandle *> ivs, ArrayRef<SubViewOp::Range> ranges) {
MutableArrayRef<Value> ivs, ArrayRef<SubViewOp::Range> ranges) {
loops.reserve(ranges.size());
for (unsigned i = 0, e = ranges.size(); i < e; ++i) {
loops.emplace_back(ivs[i], ranges[i]);
loops.emplace_back(&ivs[i], ranges[i]);
}
assert(loops.size() == ivs.size() && "Mismatch loops vs ivs size");
}
mlir::edsc::LoopNestRangeBuilder::LoopNestRangeBuilder(
ArrayRef<ValueHandle *> ivs, ArrayRef<ValueHandle> ranges) {
MutableArrayRef<Value> ivs, ArrayRef<Value> ranges) {
loops.reserve(ranges.size());
for (unsigned i = 0, e = ranges.size(); i < e; ++i) {
loops.emplace_back(ivs[i], ranges[i]);
loops.emplace_back(&ivs[i], ranges[i]);
}
assert(loops.size() == ivs.size() && "Mismatch loops vs ivs size");
}
mlir::edsc::LoopNestRangeBuilder::LoopNestRangeBuilder(
ArrayRef<ValueHandle *> ivs, ArrayRef<Value> ranges)
: LoopNestRangeBuilder(
ivs, SmallVector<ValueHandle, 4>(ranges.begin(), ranges.end())) {}
ValueHandle LoopNestRangeBuilder::LoopNestRangeBuilder::operator()(
Value LoopNestRangeBuilder::LoopNestRangeBuilder::operator()(
std::function<void(void)> fun) {
if (fun)
fun();
for (auto &lit : reverse(loops)) {
lit({});
}
return ValueHandle::null();
return Value();
}
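Usage-wise, callers now pass the induction-variable storage directly. A minimal sketch, assuming an active ScopedContext and that `ranges` holds Values of !linalg.range type:

```c++
SmallVector<Value, 4> ivs(ranges.size());
LoopNestRangeBuilder(ivs, ranges)([&] {
  // The body is emitted with ivs bound to the loop induction variables.
});
```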
namespace mlir {
@@ -91,15 +83,15 @@ namespace edsc {
template <>
GenericLoopNestRangeBuilder<loop::ForOp>::GenericLoopNestRangeBuilder(
ArrayRef<edsc::ValueHandle *> ivs, ArrayRef<Value> ranges) {
MutableArrayRef<Value> ivs, ArrayRef<Value> ranges) {
builder = std::make_unique<LoopNestRangeBuilder>(ivs, ranges);
}
template <>
GenericLoopNestRangeBuilder<AffineForOp>::GenericLoopNestRangeBuilder(
ArrayRef<ValueHandle *> ivs, ArrayRef<Value> ranges) {
SmallVector<ValueHandle, 4> lbs;
SmallVector<ValueHandle, 4> ubs;
MutableArrayRef<Value> ivs, ArrayRef<Value> ranges) {
SmallVector<Value, 4> lbs;
SmallVector<Value, 4> ubs;
SmallVector<int64_t, 4> steps;
for (Value range : ranges) {
assert(range.getType() && "expected linalg.range type");
@@ -114,8 +106,8 @@ GenericLoopNestRangeBuilder<AffineForOp>::GenericLoopNestRangeBuilder(
template <>
GenericLoopNestRangeBuilder<loop::ParallelOp>::GenericLoopNestRangeBuilder(
ArrayRef<ValueHandle *> ivs, ArrayRef<Value> ranges) {
SmallVector<ValueHandle, 4> lbs, ubs, steps;
MutableArrayRef<Value> ivs, ArrayRef<Value> ranges) {
SmallVector<Value, 4> lbs, ubs, steps;
for (Value range : ranges) {
assert(range.getType() && "expected linalg.range type");
assert(range.getDefiningOp() && "need operations to extract range parts");
@@ -197,10 +189,9 @@ Operation *mlir::edsc::makeGenericLinalgOp(
OpBuilder opBuilder(op);
ScopedContext scope(opBuilder, op->getLoc());
BlockHandle b;
auto handles = makeValueHandles(blockTypes);
BlockBuilder(&b, op->getRegion(0),
makeHandlePointers(MutableArrayRef<ValueHandle>(handles)))(
[&] { regionBuilder(b.getBlock()->getArguments()); });
SmallVector<Value, 8> handles(blockTypes.size());
BlockBuilder(&b, op->getRegion(0), blockTypes,
handles)([&] { regionBuilder(b.getBlock()->getArguments()); });
assert(op->getRegion(0).getBlocks().size() == 1);
return op;
}
@@ -209,16 +200,16 @@ void mlir::edsc::ops::mulRegionBuilder(ArrayRef<BlockArgument> args) {
using edsc::op::operator+;
using edsc::op::operator*;
assert(args.size() == 2 && "expected 2 block arguments");
ValueHandle a(args[0]), b(args[1]);
linalg_yield((a * b).getValue());
Value a(args[0]), b(args[1]);
linalg_yield(a * b);
}
void mlir::edsc::ops::macRegionBuilder(ArrayRef<BlockArgument> args) {
using edsc::op::operator+;
using edsc::op::operator*;
assert(args.size() == 3 && "expected 3 block arguments");
ValueHandle a(args[0]), b(args[1]), c(args[2]);
linalg_yield((c + a * b).getValue());
Value a(args[0]), b(args[1]), c(args[2]);
linalg_yield(c + a * b);
}
Operation *mlir::edsc::ops::linalg_generic_pointwise(
@@ -228,14 +219,14 @@ Operation *mlir::edsc::ops::linalg_generic_pointwise(
if (O.getType().isa<RankedTensorType>()) {
auto fun = [&unaryOp](ArrayRef<BlockArgument> args) {
assert(args.size() == 1 && "expected 1 block arguments");
ValueHandle a(args[0]);
Value a(args[0]);
linalg_yield(unaryOp(a));
};
return makeGenericLinalgOp(iterTypes, {I}, {O}, fun);
}
auto fun = [&unaryOp](ArrayRef<BlockArgument> args) {
assert(args.size() == 2 && "expected 2 block arguments");
ValueHandle a(args[0]);
Value a(args[0]);
linalg_yield(unaryOp(a));
};
return makeGenericLinalgOp(iterTypes, {I}, {O}, fun);
@@ -243,8 +234,7 @@ Operation *mlir::edsc::ops::linalg_generic_pointwise(
Operation *mlir::edsc::ops::linalg_generic_pointwise_tanh(StructuredIndexed I,
StructuredIndexed O) {
UnaryPointwiseOpBuilder unOp(
[](ValueHandle a) -> Value { return std_tanh(a); });
UnaryPointwiseOpBuilder unOp([](Value a) -> Value { return std_tanh(a); });
return linalg_generic_pointwise(unOp, I, O);
}
@@ -257,14 +247,14 @@ Operation *mlir::edsc::ops::linalg_generic_pointwise(
if (O.getType().isa<RankedTensorType>()) {
auto fun = [&binaryOp](ArrayRef<BlockArgument> args) {
assert(args.size() == 2 && "expected 2 block arguments");
ValueHandle a(args[0]), b(args[1]);
Value a(args[0]), b(args[1]);
linalg_yield(binaryOp(a, b));
};
return makeGenericLinalgOp(iterTypes, {I1, I2}, {O}, fun);
}
auto fun = [&binaryOp](ArrayRef<BlockArgument> args) {
assert(args.size() == 3 && "expected 3 block arguments");
ValueHandle a(args[0]), b(args[1]);
Value a(args[0]), b(args[1]);
linalg_yield(binaryOp(a, b));
};
return makeGenericLinalgOp(iterTypes, {I1, I2}, {O}, fun);
@@ -275,23 +265,22 @@ Operation *mlir::edsc::ops::linalg_generic_pointwise_add(StructuredIndexed I1,
StructuredIndexed O) {
using edsc::op::operator+;
BinaryPointwiseOpBuilder binOp(
[](ValueHandle a, ValueHandle b) -> Value { return a + b; });
[](Value a, Value b) -> Value { return a + b; });
return linalg_generic_pointwise(binOp, I1, I2, O);
}
Operation *mlir::edsc::ops::linalg_generic_pointwise_max(StructuredIndexed I1,
StructuredIndexed I2,
StructuredIndexed O) {
BinaryPointwiseOpBuilder binOp([](ValueHandle a, ValueHandle b) -> Value {
BinaryPointwiseOpBuilder binOp([](Value a, Value b) -> Value {
using edsc::op::operator>;
return std_select(a > b, a, b).getValue();
return std_select(a > b, a, b);
});
return linalg_generic_pointwise(binOp, I1, I2, O);
}
Operation *
mlir::edsc::ops::linalg_generic_matmul(ValueHandle vA, ValueHandle vB,
ValueHandle vC,
mlir::edsc::ops::linalg_generic_matmul(Value vA, Value vB, Value vC,
MatmulRegionBuilder regionBuilder) {
// clang-format off
AffineExpr m, n, k;
@@ -306,8 +295,7 @@ mlir::edsc::ops::linalg_generic_matmul(ValueHandle vA, ValueHandle vB,
}
Operation *
mlir::edsc::ops::linalg_generic_matmul(ValueHandle vA, ValueHandle vB,
RankedTensorType tC,
mlir::edsc::ops::linalg_generic_matmul(Value vA, Value vB, RankedTensorType tC,
MatmulRegionBuilder regionBuilder) {
// clang-format off
AffineExpr m, n, k;
@@ -322,8 +310,8 @@ mlir::edsc::ops::linalg_generic_matmul(ValueHandle vA, ValueHandle vB,
}
Operation *
mlir::edsc::ops::linalg_generic_matmul(ValueHandle vA, ValueHandle vB,
ValueHandle vC, RankedTensorType tD,
mlir::edsc::ops::linalg_generic_matmul(Value vA, Value vB, Value vC,
RankedTensorType tD,
MatmulRegionBuilder regionBuilder) {
// clang-format off
AffineExpr m, n, k;
@@ -337,9 +325,8 @@ mlir::edsc::ops::linalg_generic_matmul(ValueHandle vA, ValueHandle vB,
// clang-format on
}
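Putting the pieces together, emitting a generic matmul is a one-liner. A hypothetical snippet, assuming an active ScopedContext and buffer-typed Values `vA`, `vB`, `vC`:

```c++
// Scalar body C(m, n) += A(m, k) * B(k, n), provided by macRegionBuilder.
linalg_generic_matmul(vA, vB, vC, macRegionBuilder);
```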
Operation *mlir::edsc::ops::linalg_generic_conv_nhwc(ValueHandle vI,
ValueHandle vW,
ValueHandle vO,
Operation *mlir::edsc::ops::linalg_generic_conv_nhwc(Value vI, Value vW,
Value vO,
ArrayRef<int> strides,
ArrayRef<int> dilations) {
MLIRContext *ctx = ScopedContext::getContext();
@@ -373,8 +360,8 @@ Operation *mlir::edsc::ops::linalg_generic_conv_nhwc(ValueHandle vI,
}
Operation *mlir::edsc::ops::linalg_generic_dilated_conv_nhwc(
ValueHandle vI, ValueHandle vW, ValueHandle vO, int depth_multiplier,
ArrayRef<int> strides, ArrayRef<int> dilations) {
Value vI, Value vW, Value vO, int depth_multiplier, ArrayRef<int> strides,
ArrayRef<int> dilations) {
MLIRContext *ctx = ScopedContext::getContext();
// TODO(ntv) some template magic to make everything rank-polymorphic.
assert((dilations.empty() || dilations.size() == 2) && "only 2-D conv atm");

View File

@@ -35,7 +35,7 @@ using namespace mlir::edsc;
using namespace mlir::edsc::intrinsics;
using namespace mlir::linalg;
using folded_std_constant_index = folded::ValueBuilder<ConstantIndexOp>;
using folded_std_constant_index = FoldedValueBuilder<ConstantIndexOp>;
using llvm::dbgs;

View File

@@ -29,17 +29,16 @@ using namespace mlir::edsc::intrinsics;
using namespace mlir::linalg;
using edsc::op::operator+;
using edsc::op::operator==;
using mlir::edsc::intrinsics::detail::ValueHandleArray;
static SmallVector<ValueHandle, 8>
makeCanonicalAffineApplies(OpBuilder &b, Location loc, AffineMap map,
ArrayRef<Value> vals) {
static SmallVector<Value, 8> makeCanonicalAffineApplies(OpBuilder &b,
Location loc,
AffineMap map,
ArrayRef<Value> vals) {
if (map.isEmpty())
return {};
assert(map.getNumSymbols() == 0);
assert(map.getNumInputs() == vals.size());
SmallVector<ValueHandle, 8> res;
SmallVector<Value, 8> res;
res.reserve(map.getNumResults());
auto dims = map.getNumDims();
for (auto e : map.getResults()) {
@@ -80,10 +79,10 @@ SmallVector<Value, 4> emitLoopRanges(OpBuilder &b, Location loc, AffineMap map,
}
template <typename OpType>
static void inlineRegionAndEmitStdStore(OpType op,
ArrayRef<Value> indexedValues,
ArrayRef<ValueHandleArray> indexing,
ArrayRef<Value> outputBuffers) {
static void
inlineRegionAndEmitStdStore(OpType op, ArrayRef<Value> indexedValues,
ArrayRef<SmallVector<Value, 8>> indexing,
ArrayRef<Value> outputBuffers) {
auto &b = ScopedContext::getBuilder();
auto &block = op.region().front();
BlockAndValueMapping map;
@@ -99,25 +98,27 @@ static void inlineRegionAndEmitStdStore(OpType op,
"expected an yield op in the end of the region");
for (unsigned i = 0, e = terminator.getNumOperands(); i < e; ++i) {
std_store(map.lookupOrDefault(terminator.getOperand(i)), outputBuffers[i],
indexing[i]);
ArrayRef<Value>{indexing[i].begin(), indexing[i].end()});
}
}
// Returns a struct containing the input indices and output indices of a
// SingleInputPoolingOp `op`.
struct InputAndOutputIndices {
SmallVector<Value, 8> inputs;
SmallVector<Value, 8> outputs;
};
template <typename SingleInputPoolingOp>
static std::pair<SmallVector<ValueHandle, 8>, SmallVector<ValueHandle, 8>>
getInputAndOutputIndices(ArrayRef<Value> allIvs, SingleInputPoolingOp op) {
static InputAndOutputIndices getInputAndOutputIndices(ArrayRef<Value> allIvs,
SingleInputPoolingOp op) {
auto &b = ScopedContext::getBuilder();
auto loc = ScopedContext::getLocation();
auto mapsRange = op.indexing_maps().template getAsRange<AffineMapAttr>();
auto maps = llvm::to_vector<8>(
llvm::map_range(mapsRange, [](AffineMapAttr a) { return a.getValue(); }));
SmallVector<ValueHandle, 8> iIdx(
makeCanonicalAffineApplies(b, loc, maps[0], allIvs));
SmallVector<ValueHandle, 8> oIdx(
makeCanonicalAffineApplies(b, loc, maps[2], allIvs));
return {iIdx, oIdx};
return InputAndOutputIndices{
makeCanonicalAffineApplies(b, loc, maps[0], allIvs),
makeCanonicalAffineApplies(b, loc, maps[2], allIvs)};
}
namespace {
@@ -150,8 +151,8 @@ public:
permuteIvs(allIvs.take_front(nPar), copyOp.inputPermutation());
auto outputIvs =
permuteIvs(allIvs.take_front(nPar), copyOp.outputPermutation());
SmallVector<ValueHandle, 8> iivs(inputIvs.begin(), inputIvs.end());
SmallVector<ValueHandle, 8> oivs(outputIvs.begin(), outputIvs.end());
SmallVector<Value, 8> iivs(inputIvs.begin(), inputIvs.end());
SmallVector<Value, 8> oivs(outputIvs.begin(), outputIvs.end());
IndexedValueType O(copyOp.getOutputBuffer(0)), I(copyOp.getInput(0));
// Emit the proper scalar assignment, whether we are dealing with a 0-D or
// an n-D loop nest; with or without permutations.
@@ -170,13 +171,11 @@ public:
"expected linalg op with buffer semantics");
auto nPar = fillOp.getNumParallelLoops();
assert(nPar == allIvs.size());
auto ivs =
SmallVector<ValueHandle, 4>(allIvs.begin(), allIvs.begin() + nPar);
auto ivs = SmallVector<Value, 4>(allIvs.begin(), allIvs.begin() + nPar);
IndexedValueType O(fillOp.getOutputBuffer(0));
// Emit the proper scalar assignment, whether we are dealing with a 0-D or
// an n-D loop nest; with or without permutations.
nPar > 0 ? O(ivs) = ValueHandle(fillOp.value())
: O() = ValueHandle(fillOp.value());
nPar > 0 ? O(ivs) = fillOp.value() : O() = fillOp.value();
}
};
@@ -187,7 +186,7 @@ public:
assert(dotOp.hasBufferSemantics() &&
"expected linalg op with buffer semantics");
assert(allIvs.size() == 1);
ValueHandle r_i(allIvs[0]);
Value r_i(allIvs[0]);
IndexedValueType A(dotOp.getInput(0)), B(dotOp.getInput(1)),
C(dotOp.getOutputBuffer(0));
// Emit scalar form.
@@ -203,7 +202,7 @@ public:
assert(matvecOp.hasBufferSemantics() &&
"expected linalg op with buffer semantics");
assert(allIvs.size() == 2);
ValueHandle i(allIvs[0]), r_j(allIvs[1]);
Value i(allIvs[0]), r_j(allIvs[1]);
IndexedValueType A(matvecOp.getInput(0)), B(matvecOp.getInput(1)),
C(matvecOp.getOutputBuffer(0));
// Emit scalar form.
@@ -219,7 +218,7 @@ public:
assert(matmulOp.hasBufferSemantics() &&
"expected linalg op with buffer semantics");
assert(allIvs.size() == 3);
ValueHandle i(allIvs[0]), j(allIvs[1]), r_k(allIvs[2]);
Value i(allIvs[0]), j(allIvs[1]), r_k(allIvs[2]);
IndexedValueType A(matmulOp.getInput(0)), B(matmulOp.getInput(1)),
C(matmulOp.getOutputBuffer(0));
// Emit scalar form.
@@ -232,16 +231,16 @@ class LinalgScopedEmitter<IndexedValueType, ConvOp> {
public:
/// Returns the input value of convOp. If the indices in `imIdx` are out of
/// bounds, returns 0 instead.
static ValueHandle getConvOpInput(ConvOp convOp, IndexedValueType im,
ArrayRef<ValueHandle> imIdx) {
static Value getConvOpInput(ConvOp convOp, IndexedValueType im,
MutableArrayRef<Value> imIdx) {
// TODO(ntv): add a level of indirection to linalg.generic.
if (!convOp.padding())
return im(imIdx);
auto *context = ScopedContext::getContext();
ValueHandle zeroIndex = std_constant_index(0);
SmallVector<ValueHandle, 8> conds;
SmallVector<ValueHandle, 8> clampedImIdx;
Value zeroIndex = std_constant_index(0);
SmallVector<Value, 8> conds;
SmallVector<Value, 8> clampedImIdx;
for (auto iter : llvm::enumerate(imIdx)) {
int idx = iter.index();
auto dim = iter.value();
@@ -254,12 +253,12 @@ public:
using edsc::op::operator<;
using edsc::op::operator>=;
using edsc::op::operator||;
ValueHandle leftOutOfBound = dim < zeroIndex;
Value leftOutOfBound = dim < zeroIndex;
if (conds.empty())
conds.push_back(leftOutOfBound);
else
conds.push_back(conds.back() || leftOutOfBound);
ValueHandle rightBound = std_dim(convOp.input(), idx);
Value rightBound = std_dim(convOp.input(), idx);
conds.push_back(conds.back() || (dim >= rightBound));
// When padding is involved, the indices will only be shifted to negative,
@@ -274,10 +273,10 @@ public:
auto b = ScopedContext::getBuilder();
Type type = convOp.input().getType().cast<MemRefType>().getElementType();
ValueHandle zero = std_constant(type, b.getZeroAttr(type));
ValueHandle readInput = im(clampedImIdx);
Value zero = std_constant(type, b.getZeroAttr(type));
Value readInput = im(clampedImIdx);
return conds.empty() ? readInput
: std_select(conds.back(), zero, readInput);
: (Value)std_select(conds.back(), zero, readInput);
}
static void emitScalarImplementation(ArrayRef<Value> allIvs, ConvOp convOp) {
@@ -288,16 +287,16 @@ public:
auto mapsRange = convOp.indexing_maps().getAsRange<AffineMapAttr>();
auto maps = llvm::to_vector<8>(llvm::map_range(
mapsRange, [](AffineMapAttr a) { return a.getValue(); }));
SmallVector<ValueHandle, 8> fIdx(
SmallVector<Value, 8> fIdx(
makeCanonicalAffineApplies(b, loc, maps[0], allIvs));
SmallVector<ValueHandle, 8> imIdx(
SmallVector<Value, 8> imIdx(
makeCanonicalAffineApplies(b, loc, maps[1], allIvs));
SmallVector<ValueHandle, 8> oIdx(
SmallVector<Value, 8> oIdx(
makeCanonicalAffineApplies(b, loc, maps[2], allIvs));
IndexedValueType F(convOp.filter()), I(convOp.input()), O(convOp.output());
// Emit scalar form.
ValueHandle paddedInput = getConvOpInput(convOp, I, imIdx);
Value paddedInput = getConvOpInput(convOp, I, imIdx);
O(oIdx) += F(fIdx) * paddedInput;
}
};
@@ -308,15 +307,12 @@ public:
static void emitScalarImplementation(ArrayRef<Value> allIvs,
PoolingMaxOp op) {
auto indices = getInputAndOutputIndices(allIvs, op);
ValueHandleArray iIdx(indices.first);
ValueHandleArray oIdx(indices.second);
// Emit scalar form.
ValueHandle lhs = std_load(op.output(), oIdx);
ValueHandle rhs = std_load(op.input(), iIdx);
Value lhs = std_load(op.output(), indices.outputs);
Value rhs = std_load(op.input(), indices.inputs);
using edsc::op::operator>;
ValueHandle maxValue = std_select(lhs > rhs, lhs, rhs);
std_store(maxValue, op.output(), oIdx);
Value maxValue = std_select(lhs > rhs, lhs, rhs);
std_store(maxValue, op.output(), indices.outputs);
}
};
@@ -326,15 +322,12 @@ public:
static void emitScalarImplementation(ArrayRef<Value> allIvs,
PoolingMinOp op) {
auto indices = getInputAndOutputIndices(allIvs, op);
ValueHandleArray iIdx(indices.first);
ValueHandleArray oIdx(indices.second);
// Emit scalar form.
ValueHandle lhs = std_load(op.output(), oIdx);
ValueHandle rhs = std_load(op.input(), iIdx);
Value lhs = std_load(op.output(), indices.outputs);
Value rhs = std_load(op.input(), indices.inputs);
using edsc::op::operator<;
ValueHandle minValue = std_select(lhs < rhs, lhs, rhs);
std_store(minValue, op.output(), oIdx);
Value minValue = std_select(lhs < rhs, lhs, rhs);
std_store(minValue, op.output(), indices.outputs);
}
};
@@ -344,12 +337,10 @@ public:
static void emitScalarImplementation(ArrayRef<Value> allIvs,
PoolingSumOp op) {
auto indices = getInputAndOutputIndices(allIvs, op);
SmallVector<ValueHandle, 8> iIdx = indices.first;
SmallVector<ValueHandle, 8> oIdx = indices.second;
IndexedValueType input(op.input()), output(op.output());
// Emit scalar form.
output(oIdx) += input(iIdx);
output(indices.outputs) += input(indices.inputs);
}
};
@@ -392,15 +383,14 @@ public:
"expected linalg op with buffer semantics");
auto b = ScopedContext::getBuilder();
auto loc = ScopedContext::getLocation();
using edsc::intrinsics::detail::ValueHandleArray;
unsigned nInputs = genericOp.getNumInputs();
unsigned nOutputs = genericOp.getNumOutputs();
SmallVector<Value, 4> indexedValues(nInputs + nOutputs);
// 1.a. Emit std_load from input views.
for (unsigned i = 0; i < nInputs; ++i) {
ValueHandleArray indexing(makeCanonicalAffineApplies(
b, loc, genericOp.getInputIndexingMap(i), allIvs));
auto indexing = makeCanonicalAffineApplies(
b, loc, genericOp.getInputIndexingMap(i), allIvs);
indexedValues[i] = std_load(genericOp.getInput(i), indexing);
}
@@ -409,18 +399,18 @@ public:
// region has no uses.
for (unsigned i = 0; i < nOutputs; ++i) {
Value output = genericOp.getOutputBuffer(i);
ValueHandleArray indexing(makeCanonicalAffineApplies(
b, loc, genericOp.getOutputIndexingMap(i), allIvs));
auto indexing = makeCanonicalAffineApplies(
b, loc, genericOp.getOutputIndexingMap(i), allIvs);
indexedValues[nInputs + i] = std_load(output, indexing);
}
// TODO(ntv): When a region inliner exists, use it.
// 2. Inline region, currently only works for a single basic block.
// 3. Emit std_store.
SmallVector<ValueHandleArray, 8> indexing;
SmallVector<SmallVector<Value, 8>, 8> indexing;
SmallVector<Value, 8> outputBuffers;
for (unsigned i = 0; i < nOutputs; ++i) {
indexing.emplace_back(makeCanonicalAffineApplies(
indexing.push_back(makeCanonicalAffineApplies(
b, loc, genericOp.getOutputIndexingMap(i), allIvs));
outputBuffers.push_back(genericOp.getOutputBuffer(i));
}
@@ -468,7 +458,6 @@ public:
"expected linalg op with buffer semantics");
auto b = ScopedContext::getBuilder();
auto loc = ScopedContext::getLocation();
using edsc::intrinsics::detail::ValueHandleArray;
unsigned nInputs = indexedGenericOp.getNumInputs();
unsigned nOutputs = indexedGenericOp.getNumOutputs();
unsigned nLoops = allIvs.size();
@@ -481,26 +470,26 @@ public:
// 1.a. Emit std_load from input views.
for (unsigned i = 0; i < nInputs; ++i) {
Value input = indexedGenericOp.getInput(i);
ValueHandleArray indexing(makeCanonicalAffineApplies(
b, loc, indexedGenericOp.getInputIndexingMap(i), allIvs));
auto indexing = makeCanonicalAffineApplies(
b, loc, indexedGenericOp.getInputIndexingMap(i), allIvs);
indexedValues[nLoops + i] = std_load(input, indexing);
}
// 1.b. Emit std_load from output views.
for (unsigned i = 0; i < nOutputs; ++i) {
Value output = indexedGenericOp.getOutputBuffer(i);
ValueHandleArray indexing(makeCanonicalAffineApplies(
b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs));
auto indexing = makeCanonicalAffineApplies(
b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs);
indexedValues[nLoops + nInputs + i] = std_load(output, indexing);
}
// TODO(ntv): When a region inliner exists, use it.
// 2. Inline region, currently only works for a single basic block.
// 3. Emit std_store.
SmallVector<ValueHandleArray, 8> indexing;
SmallVector<SmallVector<Value, 8>, 8> indexing;
SmallVector<Value, 8> outputBuffers;
for (unsigned i = 0; i < nOutputs; ++i) {
indexing.emplace_back(makeCanonicalAffineApplies(
indexing.push_back(makeCanonicalAffineApplies(
b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs));
outputBuffers.push_back(indexedGenericOp.getOutputBuffer(i));
}
@@ -533,11 +522,8 @@ public:
typename std::conditional<std::is_same<LoopTy, AffineForOp>::value,
AffineIndexedValue, StdIndexedValue>::type;
static void doit(ConcreteOpTy linalgOp, ArrayRef<Value> loopRanges,
MutableArrayRef<ValueHandle> allIvs) {
SmallVector<ValueHandle *, 4> allPIvs =
makeHandlePointers(MutableArrayRef<ValueHandle>(allIvs));
GenericLoopNestRangeBuilder<LoopTy>(allPIvs, loopRanges)([&] {
MutableArrayRef<Value> allIvs) {
GenericLoopNestRangeBuilder<LoopTy>(allIvs, loopRanges)([&] {
SmallVector<Value, 4> allIvValues(allIvs.begin(), allIvs.end());
LinalgScopedEmitter<IndexedValueTy,
ConcreteOpTy>::emitScalarImplementation(allIvValues,
@@ -555,7 +541,7 @@ public:
using IndexedValueTy = StdIndexedValue;
static void doit(ConcreteOpTy linalgOp, ArrayRef<Value> loopRanges,
MutableArrayRef<ValueHandle> allIvs) {
MutableArrayRef<Value> allIvs) {
// Only generate loop.parallel for outer consecutive "parallel"
// iterator_types.
// TODO(ravishankarm): Generate loop.parallel for all "parallel" iterator
@@ -575,24 +561,18 @@ public:
// If there are no outer parallel loops, the number of loop ops is the same
// as the number of loops, and they are all loop.for ops.
auto nLoopOps = (nOuterPar ? nLoops - nOuterPar + 1 : nLoops);
SmallVector<ValueHandle *, 4> allPIvs =
makeHandlePointers(MutableArrayRef<ValueHandle>(allIvs));
SmallVector<OperationHandle, 4> allLoops(nLoopOps, OperationHandle());
SmallVector<OperationHandle *, 4> allPLoops;
allPLoops.reserve(allLoops.size());
for (OperationHandle &loop : allLoops)
allPLoops.push_back(&loop);
ArrayRef<ValueHandle *> allPIvsRef(allPIvs);
ArrayRef<OperationHandle *> allPLoopsRef(allPLoops);
if (nOuterPar) {
GenericLoopNestRangeBuilder<loop::ParallelOp>(
allPIvsRef.take_front(nOuterPar),
loopRanges.take_front(nOuterPar))([&] {
allIvs.take_front(nOuterPar), loopRanges.take_front(nOuterPar))([&] {
GenericLoopNestRangeBuilder<loop::ForOp>(
allPIvsRef.drop_front(nOuterPar),
allIvs.drop_front(nOuterPar),
loopRanges.drop_front(nOuterPar))([&] {
SmallVector<Value, 4> allIvValues(allIvs.begin(), allIvs.end());
LinalgScopedEmitter<StdIndexedValue, ConcreteOpTy>::
@@ -602,7 +582,7 @@ public:
} else {
// If there are no parallel loops, then fall back to generating all loop.for
// operations.
GenericLoopNestRangeBuilder<loop::ForOp>(allPIvsRef, loopRanges)([&] {
GenericLoopNestRangeBuilder<loop::ForOp>(allIvs, loopRanges)([&] {
SmallVector<Value, 4> allIvValues(allIvs.begin(), allIvs.end());
LinalgScopedEmitter<StdIndexedValue,
ConcreteOpTy>::emitScalarImplementation(allIvValues,
@@ -645,8 +625,7 @@ LinalgOpToLoopsImpl<LoopTy, ConcreteOpTy>::doit(Operation *op,
return LinalgLoops();
}
SmallVector<ValueHandle, 4> allIvs(nLoops,
ValueHandle(rewriter.getIndexType()));
SmallVector<Value, 4> allIvs(nLoops);
auto loopRanges =
emitLoopRanges(scope.getBuilder(), scope.getLocation(), invertedMap,
getViewSizes(rewriter, linalgOp));
@@ -655,12 +634,12 @@ LinalgOpToLoopsImpl<LoopTy, ConcreteOpTy>::doit(Operation *op,
// Number of loop ops might be different from the number of ivs since some
// loops like affine.parallel and loop.parallel have multiple ivs.
llvm::SetVector<Operation *> loopSet;
for (ValueHandle &iv : allIvs) {
if (!iv.hasValue())
for (Value iv : allIvs) {
if (!iv)
return {};
// The induction variable is a block argument of the entry block of the
// loop operation.
BlockArgument ivVal = iv.getValue().dyn_cast<BlockArgument>();
BlockArgument ivVal = iv.dyn_cast<BlockArgument>();
if (!ivVal)
return {};
loopSet.insert(ivVal.getOwner()->getParentOp());

View File

@@ -16,6 +16,7 @@
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/EDSC/Intrinsics.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Matchers.h"
@@ -219,10 +220,6 @@ LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
SmallVector<Value, 0> mlir::linalg::vectorizeLinalgOp(PatternRewriter &rewriter,
Operation *op) {
using vector_contract = edsc::intrinsics::ValueBuilder<vector::ContractionOp>;
using vector_broadcast = edsc::intrinsics::ValueBuilder<vector::BroadcastOp>;
using vector_type_cast = edsc::intrinsics::ValueBuilder<vector::TypeCastOp>;
assert(succeeded(vectorizeLinalgOpPrecondition(op)) &&
"DRR failure case must be a precondition");
auto linalgOp = cast<linalg::LinalgOp>(op);
@@ -242,8 +239,8 @@ SmallVector<Value, 0> mlir::linalg::vectorizeLinalgOp(PatternRewriter &rewriter,
"]: Rewrite linalg.fill as vector.broadcast: "
<< *op << ":\n");
auto dstMemrefVec = vector_type_cast(fillOp.getOutputBuffer(0));
auto dstVec = std_load(dstMemrefVec);
auto resVec = vector_broadcast(dstVec, fillOp.value());
Value dstVec = std_load(dstMemrefVec);
auto resVec = vector_broadcast(dstVec.getType(), fillOp.value());
std_store(resVec, dstMemrefVec);
} else {
// Vectorize other ops as vector contraction (currently only matmul).

View File

@@ -36,11 +36,11 @@ using namespace mlir::loop;
using llvm::SetVector;
using folded_affine_min = folded::ValueBuilder<AffineMinOp>;
using folded_linalg_range = folded::ValueBuilder<linalg::RangeOp>;
using folded_std_dim = folded::ValueBuilder<DimOp>;
using folded_std_subview = folded::ValueBuilder<SubViewOp>;
using folded_std_view = folded::ValueBuilder<ViewOp>;
using folded_affine_min = FoldedValueBuilder<AffineMinOp>;
using folded_linalg_range = FoldedValueBuilder<linalg::RangeOp>;
using folded_std_dim = FoldedValueBuilder<DimOp>;
using folded_std_subview = FoldedValueBuilder<SubViewOp>;
using folded_std_view = FoldedValueBuilder<ViewOp>;
#define DEBUG_TYPE "linalg-promotion"
@@ -74,8 +74,8 @@ static Value allocBuffer(Type elementType, Value size, bool dynamicBuffers,
if (!dynamicBuffers)
if (auto cst = dyn_cast_or_null<ConstantIndexOp>(size.getDefiningOp()))
return std_alloc(
MemRefType::get(width * cst.getValue(), IntegerType::get(8, ctx)), {},
alignment_attr);
MemRefType::get(width * cst.getValue(), IntegerType::get(8, ctx)),
ValueRange{}, alignment_attr);
Value mul =
folded_std_muli(folder, folded_std_constant_index(folder, width), size);
return std_alloc(MemRefType::get(-1, IntegerType::get(8, ctx)), mul,
@@ -118,7 +118,7 @@ static PromotionInfo promoteFullTileBuffer(OpBuilder &b, Location loc,
auto rangeValue = en.value();
// Try to extract a tight constant
Value size = extractSmallestConstantBoundingSize(b, loc, rangeValue.size);
allocSize = folded_std_muli(folder, allocSize, size).getValue();
allocSize = folded_std_muli(folder, allocSize, size);
fullSizes.push_back(size);
partialSizes.push_back(folded_std_dim(folder, subView, rank));
}
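For context, the folded builders thread an OperationFolder so that foldable ops are folded as soon as they are created. A minimal sketch, assuming `folder` is an OperationFolder * and `size` an index Value:

```c++
// When `folder` is null, the folded builders behave like their plain
// counterparts and simply create the ops.
Value four = folded_std_constant_index(folder, 4);
Value bytes = folded_std_muli(folder, four, size);
```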

View File

@@ -32,7 +32,7 @@ using namespace mlir::edsc::intrinsics;
using namespace mlir::linalg;
using namespace mlir::loop;
using folded_affine_min = folded::ValueBuilder<AffineMinOp>;
using folded_affine_min = FoldedValueBuilder<AffineMinOp>;
#define DEBUG_TYPE "linalg-tiling"
@@ -163,7 +163,7 @@ struct TileCheck : public AffineExprVisitor<TileCheck> {
// TODO(pifon, ntv): Investigate whether mixing implicit and explicit indices
// can lead to information loss.
static void transformIndexedGenericOpIndices(
OpBuilder &b, LinalgOp op, ArrayRef<ValueHandle *> pivs,
OpBuilder &b, LinalgOp op, SmallVectorImpl<Value> &ivs,
const LoopIndexToRangeIndexMap &loopIndexToRangeIndex) {
assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics");
auto indexedGenericOp = dyn_cast<IndexedGenericOp>(op.getOperation());
@@ -193,7 +193,7 @@ static void transformIndexedGenericOpIndices(
// Offset the index argument `i` by the value of the corresponding induction
// variable and replace all uses of the previous value.
Value newIndex = b.create<AddIOp>(indexedGenericOp.getLoc(), oldIndex,
pivs[rangeIndex->second]->getValue());
ivs[rangeIndex->second]);
for (auto &use : oldIndex.getUses()) {
if (use.getOwner() == newIndex.getDefiningOp())
continue;
@@ -376,15 +376,14 @@ Optional<TiledLinalgOp> static tileLinalgOpImpl(OpBuilder &b, LinalgOp op,
// 3. Create the tiled loops.
LinalgOp res = op;
auto ivs = ValueHandle::makeIndexHandles(loopRanges.size());
auto pivs = makeHandlePointers(MutableArrayRef<ValueHandle>(ivs));
SmallVector<Value, 4> ivs(loopRanges.size());
// Convert SubViewOp::Range to linalg_range.
SmallVector<Value, 4> linalgRanges;
for (auto &range : loopRanges) {
linalgRanges.push_back(
linalg_range(range.offset, range.size, range.stride));
}
GenericLoopNestRangeBuilder<LoopTy>(pivs, linalgRanges)([&] {
GenericLoopNestRangeBuilder<LoopTy>(ivs, linalgRanges)([&] {
auto b = ScopedContext::getBuilder();
auto loc = ScopedContext::getLocation();
SmallVector<Value, 4> ivValues(ivs.begin(), ivs.end());
@@ -405,7 +404,7 @@ Optional<TiledLinalgOp> static tileLinalgOpImpl(OpBuilder &b, LinalgOp op,
});
// 4. Transforms index arguments of `linalg.generic` w.r.t. the tiling.
transformIndexedGenericOpIndices(b, res, pivs, loopIndexToRangeIndex);
transformIndexedGenericOpIndices(b, res, ivs, loopIndexToRangeIndex);
// 5. Gather the newly created loops and return them with the new op.
SmallVector<Operation *, 8> loops;

View File

@@ -14,8 +14,8 @@ using namespace mlir;
using namespace mlir::edsc;
mlir::edsc::ParallelLoopNestBuilder::ParallelLoopNestBuilder(
ArrayRef<ValueHandle *> ivs, ArrayRef<ValueHandle> lbs,
ArrayRef<ValueHandle> ubs, ArrayRef<ValueHandle> steps) {
MutableArrayRef<Value> ivs, ArrayRef<Value> lbs, ArrayRef<Value> ubs,
ArrayRef<Value> steps) {
assert(ivs.size() == lbs.size() && "Mismatch in number of arguments");
assert(ivs.size() == ubs.size() && "Mismatch in number of arguments");
assert(ivs.size() == steps.size() && "Mismatch in number of arguments");
@@ -36,29 +36,34 @@ void mlir::edsc::ParallelLoopNestBuilder::operator()(
(*lit)();
}
mlir::edsc::LoopNestBuilder::LoopNestBuilder(ArrayRef<ValueHandle *> ivs,
ArrayRef<ValueHandle> lbs,
ArrayRef<ValueHandle> ubs,
ArrayRef<ValueHandle> steps) {
mlir::edsc::LoopNestBuilder::LoopNestBuilder(MutableArrayRef<Value> ivs,
ArrayRef<Value> lbs,
ArrayRef<Value> ubs,
ArrayRef<Value> steps) {
assert(ivs.size() == lbs.size() && "expected size of ivs and lbs to match");
assert(ivs.size() == ubs.size() && "expected size of ivs and ubs to match");
assert(ivs.size() == steps.size() &&
"expected size of ivs and steps to match");
loops.reserve(ivs.size());
for (auto it : llvm::zip(ivs, lbs, ubs, steps))
loops.emplace_back(makeLoopBuilder(std::get<0>(it), std::get<1>(it),
loops.emplace_back(makeLoopBuilder(&std::get<0>(it), std::get<1>(it),
std::get<2>(it), std::get<3>(it)));
assert(loops.size() == ivs.size() && "Mismatch loops vs ivs size");
}
mlir::edsc::LoopNestBuilder::LoopNestBuilder(
ValueHandle *iv, ValueHandle lb, ValueHandle ub, ValueHandle step,
ArrayRef<ValueHandle *> iter_args_handles,
ValueRange iter_args_init_values) {
assert(iter_args_init_values.size() == iter_args_handles.size() &&
Value *iv, Value lb, Value ub, Value step,
MutableArrayRef<Value> iterArgsHandles, ValueRange iterArgsInitValues) {
assert(iterArgsInitValues.size() == iterArgsHandles.size() &&
"expected size of arguments and argument_handles to match");
loops.emplace_back(makeLoopBuilder(iv, lb, ub, step, iter_args_handles,
iter_args_init_values));
loops.emplace_back(
makeLoopBuilder(iv, lb, ub, step, iterArgsHandles, iterArgsInitValues));
}
mlir::edsc::LoopNestBuilder::LoopNestBuilder(Value *iv, Value lb, Value ub,
Value step) {
SmallVector<Value, 0> noArgs;
loops.emplace_back(makeLoopBuilder(iv, lb, ub, step, noArgs, {}));
}
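This new convenience constructor makes the common single-loop case direct. A minimal sketch, assuming index-typed Values `lb`, `ub`, and `step` in an active ScopedContext:

```c++
Value iv; // null; filled in by the builder
LoopNestBuilder(&iv, lb, ub, step)([&] {
  // Body of a single loop.for; iv is its induction variable.
});
```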
Operation::result_range
@@ -73,10 +78,10 @@ mlir::edsc::LoopNestBuilder::LoopNestBuilder::operator()(
return loops[0].getOp()->getResults();
}
LoopBuilder mlir::edsc::makeParallelLoopBuilder(ArrayRef<ValueHandle *> ivs,
ArrayRef<ValueHandle> lbHandles,
ArrayRef<ValueHandle> ubHandles,
ArrayRef<ValueHandle> steps) {
LoopBuilder mlir::edsc::makeParallelLoopBuilder(MutableArrayRef<Value> ivs,
ArrayRef<Value> lbHandles,
ArrayRef<Value> ubHandles,
ArrayRef<Value> steps) {
LoopBuilder result;
auto opHandle = OperationHandle::create<loop::ParallelOp>(
SmallVector<Value, 4>(lbHandles.begin(), lbHandles.end()),
@@ -86,24 +91,22 @@ LoopBuilder mlir::edsc::makeParallelLoopBuilder(ArrayRef<ValueHandle *> ivs,
loop::ParallelOp parallelOp =
cast<loop::ParallelOp>(*opHandle.getOperation());
for (size_t i = 0, e = ivs.size(); i < e; ++i)
*ivs[i] = ValueHandle(parallelOp.getBody()->getArgument(i));
ivs[i] = parallelOp.getBody()->getArgument(i);
result.enter(parallelOp.getBody(), /*prev=*/1);
return result;
}
mlir::edsc::LoopBuilder
mlir::edsc::makeLoopBuilder(ValueHandle *iv, ValueHandle lbHandle,
ValueHandle ubHandle, ValueHandle stepHandle,
ArrayRef<ValueHandle *> iter_args_handles,
ValueRange iter_args_init_values) {
mlir::edsc::LoopBuilder mlir::edsc::makeLoopBuilder(
Value *iv, Value lbHandle, Value ubHandle, Value stepHandle,
MutableArrayRef<Value> iterArgsHandles, ValueRange iterArgsInitValues) {
mlir::edsc::LoopBuilder result;
auto forOp = OperationHandle::createOp<loop::ForOp>(
lbHandle, ubHandle, stepHandle, iter_args_init_values);
*iv = ValueHandle(forOp.getInductionVar());
auto *body = loop::getForInductionVarOwner(iv->getValue()).getBody();
for (size_t i = 0, e = iter_args_handles.size(); i < e; ++i) {
lbHandle, ubHandle, stepHandle, iterArgsInitValues);
*iv = forOp.getInductionVar();
auto *body = loop::getForInductionVarOwner(*iv).getBody();
for (size_t i = 0, e = iterArgsHandles.size(); i < e; ++i) {
// Skipping the induction variable.
*(iter_args_handles[i]) = ValueHandle(body->getArgument(i + 1));
iterArgsHandles[i] = body->getArgument(i + 1);
}
result.setOp(forOp);
result.enter(body, /*prev=*/1);

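For readers updating call sites, here is a minimal usage sketch of the new capture style (assuming an active `ScopedContext`; `lb` and `ub` are illustrative index-typed `Value`s, not names from this commit):

```c++
// Sketch only: induction variables are default-constructed Values that the
// builder overwrites in place, replacing the old ValueHandle* out-parameters.
Value ivs[2];
Value step = std_constant_index(1);
LoopNestBuilder(ivs, {lb, lb}, {ub, ub}, {step, step})([&] {
  // Here ivs[0] and ivs[1] hold the captured loop.for induction variables.
});
```

Passing `MutableArrayRef<Value>` lets the builder write the captured induction variables back directly, without the extra handle indirection.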
View File

@@ -14,11 +14,11 @@ using namespace mlir;
using namespace mlir::edsc;
using namespace mlir::edsc::intrinsics;
static SmallVector<ValueHandle, 8> getMemRefSizes(Value memRef) {
static SmallVector<Value, 8> getMemRefSizes(Value memRef) {
MemRefType memRefType = memRef.getType().cast<MemRefType>();
assert(isStrided(memRefType) && "Expected strided MemRef type");
SmallVector<ValueHandle, 8> res;
SmallVector<Value, 8> res;
res.reserve(memRefType.getShape().size());
const auto &shape = memRefType.getShape();
for (unsigned idx = 0, n = shape.size(); idx < n; ++idx) {

View File

@@ -13,45 +13,29 @@ using namespace mlir;
using namespace mlir::edsc;
OperationHandle mlir::edsc::intrinsics::std_br(BlockHandle bh,
ArrayRef<ValueHandle> operands) {
ArrayRef<Value> operands) {
assert(bh && "Expected already captured BlockHandle");
for (auto &o : operands) {
(void)o;
assert(o && "Expected already captured ValueHandle");
assert(o && "Expected already captured Value");
}
SmallVector<Value, 4> ops(operands.begin(), operands.end());
return OperationHandle::create<BranchOp>(bh.getBlock(), ops);
}
static void enforceEmptyCapturesMatchOperands(ArrayRef<ValueHandle *> captures,
ArrayRef<ValueHandle> operands) {
assert(captures.size() == operands.size() &&
"Expected same number of captures as operands");
for (auto it : llvm::zip(captures, operands)) {
(void)it;
assert(!std::get<0>(it)->hasValue() &&
"Unexpected already captured ValueHandle");
assert(std::get<1>(it) && "Expected already captured ValueHandle");
assert(std::get<0>(it)->getType() == std::get<1>(it).getType() &&
"Expected the same type for capture and operand");
}
}
OperationHandle mlir::edsc::intrinsics::std_br(BlockHandle *bh,
ArrayRef<ValueHandle *> captures,
ArrayRef<ValueHandle> operands) {
ArrayRef<Type> types,
MutableArrayRef<Value> captures,
ArrayRef<Value> operands) {
assert(!*bh && "Unexpected already captured BlockHandle");
enforceEmptyCapturesMatchOperands(captures, operands);
BlockBuilder(bh, captures)(/* no body */);
BlockBuilder(bh, types, captures)(/* no body */);
SmallVector<Value, 4> ops(operands.begin(), operands.end());
return OperationHandle::create<BranchOp>(bh->getBlock(), ops);
}
OperationHandle
mlir::edsc::intrinsics::std_cond_br(ValueHandle cond, BlockHandle trueBranch,
ArrayRef<ValueHandle> trueOperands,
BlockHandle falseBranch,
ArrayRef<ValueHandle> falseOperands) {
OperationHandle mlir::edsc::intrinsics::std_cond_br(
Value cond, BlockHandle trueBranch, ArrayRef<Value> trueOperands,
BlockHandle falseBranch, ArrayRef<Value> falseOperands) {
SmallVector<Value, 4> trueOps(trueOperands.begin(), trueOperands.end());
SmallVector<Value, 4> falseOps(falseOperands.begin(), falseOperands.end());
return OperationHandle::create<CondBranchOp>(
@@ -59,16 +43,14 @@ mlir::edsc::intrinsics::std_cond_br(ValueHandle cond, BlockHandle trueBranch,
}
OperationHandle mlir::edsc::intrinsics::std_cond_br(
ValueHandle cond, BlockHandle *trueBranch,
ArrayRef<ValueHandle *> trueCaptures, ArrayRef<ValueHandle> trueOperands,
BlockHandle *falseBranch, ArrayRef<ValueHandle *> falseCaptures,
ArrayRef<ValueHandle> falseOperands) {
Value cond, BlockHandle *trueBranch, ArrayRef<Type> trueTypes,
MutableArrayRef<Value> trueCaptures, ArrayRef<Value> trueOperands,
BlockHandle *falseBranch, ArrayRef<Type> falseTypes,
MutableArrayRef<Value> falseCaptures, ArrayRef<Value> falseOperands) {
assert(!*trueBranch && "Unexpected already captured BlockHandle");
assert(!*falseBranch && "Unexpected already captured BlockHandle");
enforceEmptyCapturesMatchOperands(trueCaptures, trueOperands);
enforceEmptyCapturesMatchOperands(falseCaptures, falseOperands);
BlockBuilder(trueBranch, trueCaptures)(/* no body */);
BlockBuilder(falseBranch, falseCaptures)(/* no body */);
BlockBuilder(trueBranch, trueTypes, trueCaptures)(/* no body */);
BlockBuilder(falseBranch, falseTypes, falseCaptures)(/* no body */);
SmallVector<Value, 4> trueOps(trueOperands.begin(), trueOperands.end());
SmallVector<Value, 4> falseOps(falseOperands.begin(), falseOperands.end());
return OperationHandle::create<CondBranchOp>(

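The eager branch intrinsics now take the successor's argument types explicitly, because a default-constructed `Value` carries no type before it is captured. A minimal sketch, assuming an active `ScopedContext` (names are illustrative):

```c++
// Sketch only: the ArrayRef<Type> parameter replaces the types that delayed
// ValueHandles used to carry.
BlockHandle succ;
Value captured[1];
Value c = std_constant_int(42, 32);
// Creates the successor block with one i32 argument, captures that argument
// into captured[0], and branches to it passing c.
std_br(&succ, c.getType(), captured, c);
```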
View File

@@ -65,25 +65,8 @@ MLIRContext *mlir::edsc::ScopedContext::getContext() {
return getBuilder().getContext();
}
ValueHandle &mlir::edsc::ValueHandle::operator=(const ValueHandle &other) {
assert(t == other.t && "Wrong type capture");
assert(!v && "ValueHandle has already been captured, use a new name!");
v = other.v;
return *this;
}
ValueHandle ValueHandle::create(StringRef name, ArrayRef<ValueHandle> operands,
ArrayRef<Type> resultTypes,
ArrayRef<NamedAttribute> attributes) {
Operation *op =
OperationHandle::create(name, operands, resultTypes, attributes);
if (op->getNumResults() == 1)
return ValueHandle(op->getResult(0));
llvm_unreachable("unsupported operation, use an OperationHandle instead");
}
OperationHandle OperationHandle::create(StringRef name,
ArrayRef<ValueHandle> operands,
ArrayRef<Value> operands,
ArrayRef<Type> resultTypes,
ArrayRef<NamedAttribute> attributes) {
OperationState state(ScopedContext::getLocation(), name);
@@ -156,37 +139,32 @@ mlir::edsc::BlockBuilder::BlockBuilder(BlockHandle bh, Append) {
enter(bh.getBlock());
}
mlir::edsc::BlockBuilder::BlockBuilder(BlockHandle *bh,
ArrayRef<ValueHandle *> args) {
mlir::edsc::BlockBuilder::BlockBuilder(BlockHandle *bh, ArrayRef<Type> types,
MutableArrayRef<Value> args) {
assert(!*bh && "BlockHandle already captures a block, use "
"the explicit BockBuilder(bh, Append())({}) syntax instead.");
SmallVector<Type, 8> types;
for (auto *a : args) {
assert(!a->hasValue() &&
"Expected delayed ValueHandle that has not yet captured.");
types.push_back(a->getType());
}
assert((args.empty() || args.size() == types.size()) &&
"if args captures are specified, their number must match the number "
"of types");
*bh = BlockHandle::create(types);
for (auto it : llvm::zip(args, bh->getBlock()->getArguments())) {
*(std::get<0>(it)) = ValueHandle(std::get<1>(it));
}
if (!args.empty())
for (auto it : llvm::zip(args, bh->getBlock()->getArguments()))
std::get<0>(it) = Value(std::get<1>(it));
enter(bh->getBlock());
}
mlir::edsc::BlockBuilder::BlockBuilder(BlockHandle *bh, Region &region,
ArrayRef<ValueHandle *> args) {
ArrayRef<Type> types,
MutableArrayRef<Value> args) {
assert(!*bh && "BlockHandle already captures a block, use "
"the explicit BockBuilder(bh, Append())({}) syntax instead.");
SmallVector<Type, 8> types;
for (auto *a : args) {
assert(!a->hasValue() &&
"Expected delayed ValueHandle that has not yet captured.");
types.push_back(a->getType());
}
assert((args.empty() || args.size() == types.size()) &&
"if args captures are specified, their number must match the number "
"of types");
*bh = BlockHandle::createInRegion(region, types);
for (auto it : llvm::zip(args, bh->getBlock()->getArguments())) {
*(std::get<0>(it)) = ValueHandle(std::get<1>(it));
}
if (!args.empty())
for (auto it : llvm::zip(args, bh->getBlock()->getArguments()))
std::get<0>(it) = Value(std::get<1>(it));
enter(bh->getBlock());
}

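`BlockBuilder` follows the same pattern: the argument types are passed alongside caller-provided `Value` storage for the captures. A minimal sketch, assuming an active `ScopedContext` (the `i32` type is an illustrative assumption):

```c++
// Sketch only: types drive block-argument creation; args, if non-empty,
// receives the created arguments.
BlockHandle bh;
Type i32 = IntegerType::get(32, ScopedContext::getContext());
Value args[2];
BlockBuilder(&bh, {i32, i32}, args)([&] {
  // args[0] and args[1] are now the block's arguments.
});
```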
View File

@@ -68,12 +68,11 @@ TEST_FUNC(builder_dynamic_for_func_args) {
OpBuilder builder(f.getBody());
ScopedContext scope(builder, f.getLoc());
ValueHandle i(indexType), j(indexType), lb(f.getArgument(0)),
ub(f.getArgument(1));
ValueHandle f7(std_constant_float(llvm::APFloat(7.0f), f32Type));
ValueHandle f13(std_constant_float(llvm::APFloat(13.0f), f32Type));
ValueHandle i7(std_constant_int(7, 32));
ValueHandle i13(std_constant_int(13, 32));
Value i, j, lb(f.getArgument(0)), ub(f.getArgument(1));
Value f7(std_constant_float(llvm::APFloat(7.0f), f32Type));
Value f13(std_constant_float(llvm::APFloat(13.0f), f32Type));
Value i7(std_constant_int(7, 32));
Value i13(std_constant_int(13, 32));
AffineLoopNestBuilder(&i, lb, ub, 3)([&] {
using namespace edsc::op;
    lb * std_constant_index(3) + ub;
@@ -119,8 +118,8 @@ TEST_FUNC(builder_dynamic_for) {
OpBuilder builder(f.getBody());
ScopedContext scope(builder, f.getLoc());
ValueHandle i(indexType), a(f.getArgument(0)), b(f.getArgument(1)),
c(f.getArgument(2)), d(f.getArgument(3));
Value i, a(f.getArgument(0)), b(f.getArgument(1)), c(f.getArgument(2)),
d(f.getArgument(3));
using namespace edsc::op;
AffineLoopNestBuilder(&i, a - b, c + d, 2)();
@@ -141,8 +140,8 @@ TEST_FUNC(builder_loop_for) {
OpBuilder builder(f.getBody());
ScopedContext scope(builder, f.getLoc());
ValueHandle i(indexType), a(f.getArgument(0)), b(f.getArgument(1)),
c(f.getArgument(2)), d(f.getArgument(3));
Value i, a(f.getArgument(0)), b(f.getArgument(1)), c(f.getArgument(2)),
d(f.getArgument(3));
using namespace edsc::op;
LoopNestBuilder(&i, a - b, c + d, a)();
@@ -163,8 +162,8 @@ TEST_FUNC(builder_max_min_for) {
OpBuilder builder(f.getBody());
ScopedContext scope(builder, f.getLoc());
ValueHandle i(indexType), lb1(f.getArgument(0)), lb2(f.getArgument(1)),
ub1(f.getArgument(2)), ub2(f.getArgument(3));
Value i, lb1(f.getArgument(0)), lb2(f.getArgument(1)), ub1(f.getArgument(2)),
ub2(f.getArgument(3));
AffineLoopNestBuilder(&i, {lb1, lb2}, {ub1, ub2}, 1)();
std_ret();
@@ -183,17 +182,20 @@ TEST_FUNC(builder_blocks) {
OpBuilder builder(f.getBody());
ScopedContext scope(builder, f.getLoc());
ValueHandle c1(ValueHandle::create<ConstantIntOp>(42, 32)),
c2(ValueHandle::create<ConstantIntOp>(1234, 32));
ValueHandle arg1(c1.getType()), arg2(c1.getType()), arg3(c1.getType()),
arg4(c1.getType()), r(c1.getType());
Value c1(std_constant_int(42, 32)), c2(std_constant_int(1234, 32));
Value r;
Value args12[2];
Value &arg1 = args12[0], &arg2 = args12[1];
Value args34[2];
Value &arg3 = args34[0], &arg4 = args34[1];
BlockHandle b1, b2, functionBlock(&f.front());
BlockBuilder(&b1, {&arg1, &arg2})(
BlockBuilder(&b1, {c1.getType(), c1.getType()}, args12)(
      // b2 has not yet been constructed; we need to come back later.
      // This is a byproduct of non-structured control flow.
);
BlockBuilder(&b2, {&arg3, &arg4})([&] { std_br(b1, {arg3, arg4}); });
BlockBuilder(&b2, {c1.getType(), c1.getType()}, args34)([&] {
std_br(b1, {arg3, arg4});
});
  // The insertion point within the top-level function is now past b2, so we
  // will need to get back to the entry block.
  // This is what happens with unstructured control flow.
@@ -226,24 +228,25 @@ TEST_FUNC(builder_blocks_eager) {
OpBuilder builder(f.getBody());
ScopedContext scope(builder, f.getLoc());
ValueHandle c1(ValueHandle::create<ConstantIntOp>(42, 32)),
c2(ValueHandle::create<ConstantIntOp>(1234, 32));
ValueHandle arg1(c1.getType()), arg2(c1.getType()), arg3(c1.getType()),
arg4(c1.getType()), r(c1.getType());
Value c1(std_constant_int(42, 32)), c2(std_constant_int(1234, 32));
Value res;
Value args1And2[2], args3And4[2];
Value &arg1 = args1And2[0], &arg2 = args1And2[1], &arg3 = args3And4[0],
&arg4 = args3And4[1];
// clang-format off
BlockHandle b1, b2;
{ // Toplevel function scope.
// Build a new block for b1 eagerly.
std_br(&b1, {&arg1, &arg2}, {c1, c2});
std_br(&b1, {c1.getType(), c1.getType()}, args1And2, {c1, c2});
// Construct a new block b2 explicitly with a branch into b1.
BlockBuilder(&b2, {&arg3, &arg4})([&]{
BlockBuilder(&b2, {c1.getType(), c1.getType()}, args3And4)([&]{
std_br(b1, {arg3, arg4});
});
    // And come back to append into b1 once b2 exists.
BlockBuilder(b1, Append())([&]{
r = arg1 + arg2;
std_br(b2, {arg1, r});
res = arg1 + arg2;
std_br(b2, {arg1, res});
});
}
@@ -268,15 +271,14 @@ TEST_FUNC(builder_cond_branch) {
OpBuilder builder(f.getBody());
ScopedContext scope(builder, f.getLoc());
ValueHandle funcArg(f.getArgument(0));
ValueHandle c32(ValueHandle::create<ConstantIntOp>(32, 32)),
c64(ValueHandle::create<ConstantIntOp>(64, 64)),
c42(ValueHandle::create<ConstantIntOp>(42, 32));
ValueHandle arg1(c32.getType()), arg2(c64.getType()), arg3(c32.getType());
Value funcArg(f.getArgument(0));
Value c32(std_constant_int(32, 32)), c64(std_constant_int(64, 64)),
c42(std_constant_int(42, 32));
Value arg1;
Value args23[2];
BlockHandle b1, b2, functionBlock(&f.front());
BlockBuilder(&b1, {&arg1})([&] { std_ret(); });
BlockBuilder(&b2, {&arg2, &arg3})([&] { std_ret(); });
BlockBuilder(&b1, c32.getType(), arg1)([&] { std_ret(); });
BlockBuilder(&b2, {c64.getType(), c32.getType()}, args23)([&] { std_ret(); });
// Get back to entry block and add a conditional branch
BlockBuilder(functionBlock, Append())([&] {
std_cond_br(funcArg, b1, {c32}, b2, {c64, c42});
@@ -304,15 +306,16 @@ TEST_FUNC(builder_cond_branch_eager) {
OpBuilder builder(f.getBody());
ScopedContext scope(builder, f.getLoc());
ValueHandle funcArg(f.getArgument(0));
ValueHandle c32(ValueHandle::create<ConstantIntOp>(32, 32)),
c64(ValueHandle::create<ConstantIntOp>(64, 64)),
c42(ValueHandle::create<ConstantIntOp>(42, 32));
ValueHandle arg1(c32.getType()), arg2(c64.getType()), arg3(c32.getType());
Value arg0(f.getArgument(0));
Value c32(std_constant_int(32, 32)), c64(std_constant_int(64, 64)),
c42(std_constant_int(42, 32));
// clang-format off
BlockHandle b1, b2;
std_cond_br(funcArg, &b1, {&arg1}, {c32}, &b2, {&arg2, &arg3}, {c64, c42});
Value arg1[1], args2And3[2];
std_cond_br(arg0,
&b1, c32.getType(), arg1, c32,
&b2, {c64.getType(), c32.getType()}, args2And3, {c64, c42});
BlockBuilder(b1, Append())([]{
std_ret();
});
@@ -336,7 +339,6 @@ TEST_FUNC(builder_cond_branch_eager) {
TEST_FUNC(builder_helpers) {
using namespace edsc::op;
auto indexType = IndexType::get(&globalContext());
auto f32Type = FloatType::getF32(&globalContext());
auto memrefType =
MemRefType::get({ShapedType::kDynamicSize, ShapedType::kDynamicSize,
@@ -348,21 +350,20 @@ TEST_FUNC(builder_helpers) {
OpBuilder builder(f.getBody());
ScopedContext scope(builder, f.getLoc());
// clang-format off
ValueHandle f7(
ValueHandle::create<ConstantFloatOp>(llvm::APFloat(7.0f), f32Type));
Value f7 = std_constant_float(llvm::APFloat(7.0f), f32Type);
MemRefBoundsCapture vA(f.getArgument(0)), vB(f.getArgument(1)),
vC(f.getArgument(2));
AffineIndexedValue A(f.getArgument(0)), B(f.getArgument(1)), C(f.getArgument(2));
ValueHandle i(indexType), j(indexType), k1(indexType), k2(indexType),
lb0(indexType), lb1(indexType), lb2(indexType),
ub0(indexType), ub1(indexType), ub2(indexType);
Value ivs[2];
Value &i = ivs[0], &j = ivs[1];
Value k1, k2, lb0, lb1, lb2, ub0, ub1, ub2;
int64_t step0, step1, step2;
std::tie(lb0, ub0, step0) = vA.range(0);
std::tie(lb1, ub1, step1) = vA.range(1);
lb2 = vA.lb(2);
ub2 = vA.ub(2);
step2 = vA.step(2);
AffineLoopNestBuilder({&i, &j}, {lb0, lb1}, {ub0, ub1}, {step0, step1})([&]{
AffineLoopNestBuilder(ivs, {lb0, lb1}, {ub0, ub1}, {step0, step1})([&]{
AffineLoopNestBuilder(&k1, lb2, ub2, step2)([&]{
C(i, j, k1) = f7 + A(i, j, k1) + B(i, j, k1);
});
@@ -393,45 +394,6 @@ TEST_FUNC(builder_helpers) {
f.erase();
}
TEST_FUNC(custom_ops) {
using namespace edsc::op;
auto indexType = IndexType::get(&globalContext());
auto f = makeFunction("custom_ops", {}, {indexType, indexType});
OpBuilder builder(f.getBody());
ScopedContext scope(builder, f.getLoc());
CustomOperation<ValueHandle> MY_CUSTOM_OP("my_custom_op");
CustomOperation<OperationHandle> MY_CUSTOM_OP_0("my_custom_op_0");
CustomOperation<OperationHandle> MY_CUSTOM_OP_2("my_custom_op_2");
// clang-format off
ValueHandle vh(indexType), vh20(indexType), vh21(indexType);
OperationHandle ih0, ih2;
ValueHandle m(indexType), n(indexType);
ValueHandle M(f.getArgument(0)), N(f.getArgument(1));
ValueHandle ten(std_constant_index(10)), twenty(std_constant_index(20));
AffineLoopNestBuilder({&m, &n}, {M, N}, {M + ten, N + twenty}, {1, 1})([&]{
vh = MY_CUSTOM_OP({m, m + n}, {indexType}, {});
ih0 = MY_CUSTOM_OP_0({m, m + n}, {});
ih2 = MY_CUSTOM_OP_2({m, m + n}, {indexType, indexType});
// These captures are verbose for now, can improve when used in practice.
vh20 = ValueHandle(ih2.getOperation()->getResult(0));
vh21 = ValueHandle(ih2.getOperation()->getResult(1));
MY_CUSTOM_OP({vh20, vh21}, {indexType}, {});
});
// CHECK-LABEL: @custom_ops
// CHECK: affine.for %{{.*}} {{.*}}
// CHECK: affine.for %{{.*}} {{.*}}
// CHECK: {{.*}} = "my_custom_op"{{.*}} : (index, index) -> index
// CHECK: "my_custom_op_0"{{.*}} : (index, index) -> ()
// CHECK: [[TWO:%[a-z0-9]+]]:2 = "my_custom_op_2"{{.*}} : (index, index) -> (index, index)
// CHECK: {{.*}} = "my_custom_op"([[TWO]]#0, [[TWO]]#1) : (index, index) -> index
// clang-format on
f.print(llvm::outs());
f.erase();
}
TEST_FUNC(insertion_in_block) {
using namespace edsc::op;
auto indexType = IndexType::get(&globalContext());
@@ -441,11 +403,11 @@ TEST_FUNC(insertion_in_block) {
ScopedContext scope(builder, f.getLoc());
BlockHandle b1;
// clang-format off
ValueHandle::create<ConstantIntOp>(0, 32);
BlockBuilder(&b1, {})([]{
ValueHandle::create<ConstantIntOp>(1, 32);
std_constant_int(0, 32);
(BlockBuilder(&b1))([]{
std_constant_int(1, 32);
});
ValueHandle::create<ConstantIntOp>(2, 32);
std_constant_int(2, 32);
// CHECK-LABEL: @insertion_in_block
// CHECK: {{.*}} = constant 0 : i32
// CHECK: {{.*}} = constant 2 : i32
@@ -469,8 +431,8 @@ TEST_FUNC(zero_and_std_sign_extendi_op_i1_to_i8) {
AffineIndexedValue A(f.getArgument(0));
AffineIndexedValue B(f.getArgument(1));
// clang-format off
edsc::intrinsics::std_zero_extendi(*A, i8Type);
edsc::intrinsics::std_sign_extendi(*B, i8Type);
edsc::intrinsics::std_zero_extendi(A, i8Type);
edsc::intrinsics::std_sign_extendi(B, i8Type);
// CHECK-LABEL: @zero_and_std_sign_extendi_op
// CHECK: %[[SRC1:.*]] = affine.load
// CHECK: zexti %[[SRC1]] : i1 to i8
@@ -489,8 +451,8 @@ TEST_FUNC(operator_or) {
ScopedContext scope(builder, f.getLoc());
using op::operator||;
ValueHandle lhs(f.getArgument(0));
ValueHandle rhs(f.getArgument(1));
Value lhs(f.getArgument(0));
Value rhs(f.getArgument(1));
lhs || rhs;
// CHECK-LABEL: @operator_or
@@ -508,8 +470,8 @@ TEST_FUNC(operator_and) {
ScopedContext scope(builder, f.getLoc());
using op::operator&&;
ValueHandle lhs(f.getArgument(0));
ValueHandle rhs(f.getArgument(1));
Value lhs(f.getArgument(0));
Value rhs(f.getArgument(1));
  lhs && rhs;
// CHECK-LABEL: @operator_and
@@ -521,7 +483,6 @@ TEST_FUNC(operator_and) {
TEST_FUNC(select_op_i32) {
using namespace edsc::op;
auto indexType = IndexType::get(&globalContext());
auto f32Type = FloatType::getF32(&globalContext());
auto memrefType = MemRefType::get(
{ShapedType::kDynamicSize, ShapedType::kDynamicSize}, f32Type, {}, 0);
@@ -530,17 +491,13 @@ TEST_FUNC(select_op_i32) {
OpBuilder builder(f.getBody());
ScopedContext scope(builder, f.getLoc());
// clang-format off
ValueHandle zero = std_constant_index(0), one = std_constant_index(1);
Value zero = std_constant_index(0), one = std_constant_index(1);
MemRefBoundsCapture vA(f.getArgument(0));
AffineIndexedValue A(f.getArgument(0));
ValueHandle i(indexType), j(indexType);
AffineLoopNestBuilder({&i, &j}, {zero, zero}, {one, one}, {1, 1})([&]{
// This test exercises AffineIndexedValue::operator Value.
// Without it, one must force conversion to ValueHandle as such:
// std_select(
// i == zero, ValueHandle(A(zero, zero)), ValueHandle(ValueA(i, j)))
using edsc::op::operator==;
std_select(i == zero, *A(zero, zero), *A(i, j));
Value ivs[2];
Value &i = ivs[0], &j = ivs[1];
AffineLoopNestBuilder(ivs, {zero, zero}, {one, one}, {1, 1})([&]{
std_select(eq(i, zero), A(zero, zero), A(i, j));
});
// CHECK-LABEL: @select_op
@@ -556,7 +513,6 @@ TEST_FUNC(select_op_i32) {
}
TEST_FUNC(select_op_f32) {
auto indexType = IndexType::get(&globalContext());
auto f32Type = FloatType::getF32(&globalContext());
auto memrefType = MemRefType::get(
{ShapedType::kDynamicSize, ShapedType::kDynamicSize}, f32Type, {}, 0);
@@ -565,18 +521,19 @@ TEST_FUNC(select_op_f32) {
OpBuilder builder(f.getBody());
ScopedContext scope(builder, f.getLoc());
// clang-format off
ValueHandle zero = std_constant_index(0), one = std_constant_index(1);
Value zero = std_constant_index(0), one = std_constant_index(1);
MemRefBoundsCapture vA(f.getArgument(0)), vB(f.getArgument(1));
AffineIndexedValue A(f.getArgument(0)), B(f.getArgument(1));
ValueHandle i(indexType), j(indexType);
AffineLoopNestBuilder({&i, &j}, {zero, zero}, {one, one}, {1, 1})([&]{
Value ivs[2];
Value &i = ivs[0], &j = ivs[1];
AffineLoopNestBuilder(ivs, {zero, zero}, {one, one}, {1, 1})([&]{
using namespace edsc::op;
std_select(B(i, j) == B(i + one, j), *A(zero, zero), *A(i, j));
std_select(B(i, j) != B(i + one, j), *A(zero, zero), *A(i, j));
std_select(B(i, j) >= B(i + one, j), *A(zero, zero), *A(i, j));
std_select(B(i, j) <= B(i + one, j), *A(zero, zero), *A(i, j));
std_select(B(i, j) < B(i + one, j), *A(zero, zero), *A(i, j));
std_select(B(i, j) > B(i + one, j), *A(zero, zero), *A(i, j));
std_select(eq(B(i, j), B(i + one, j)), A(zero, zero), A(i, j));
std_select(ne(B(i, j), B(i + one, j)), A(zero, zero), A(i, j));
std_select(B(i, j) >= B(i + one, j), A(zero, zero), A(i, j));
std_select(B(i, j) <= B(i + one, j), A(zero, zero), A(i, j));
std_select(B(i, j) < B(i + one, j), A(zero, zero), A(i, j));
std_select(B(i, j) > B(i + one, j), A(zero, zero), A(i, j));
});
// CHECK-LABEL: @select_op
@@ -632,7 +589,6 @@ TEST_FUNC(select_op_f32) {
// Inject an EDSC-constructed computation to exercise imperfectly nested 2-d
// tiling.
TEST_FUNC(tile_2d) {
auto indexType = IndexType::get(&globalContext());
auto memrefType =
MemRefType::get({ShapedType::kDynamicSize, ShapedType::kDynamicSize,
ShapedType::kDynamicSize},
@@ -641,17 +597,19 @@ TEST_FUNC(tile_2d) {
OpBuilder builder(f.getBody());
ScopedContext scope(builder, f.getLoc());
ValueHandle zero = std_constant_index(0);
Value zero = std_constant_index(0);
MemRefBoundsCapture vA(f.getArgument(0)), vB(f.getArgument(1)),
vC(f.getArgument(2));
AffineIndexedValue A(f.getArgument(0)), B(f.getArgument(1)),
C(f.getArgument(2));
ValueHandle i(indexType), j(indexType), k1(indexType), k2(indexType);
ValueHandle M(vC.ub(0)), N(vC.ub(1)), O(vC.ub(2));
Value ivs[2];
Value &i = ivs[0], &j = ivs[1];
Value k1, k2;
Value M(vC.ub(0)), N(vC.ub(1)), O(vC.ub(2));
// clang-format off
using namespace edsc::op;
AffineLoopNestBuilder({&i, &j}, {zero, zero}, {M, N}, {1, 1})([&]{
AffineLoopNestBuilder(ivs, {zero, zero}, {M, N}, {1, 1})([&]{
AffineLoopNestBuilder(&k1, zero, O, 1)([&]{
C(i, j, k1) = A(i, j, k1) + B(i, j, k1);
});
@@ -661,10 +619,8 @@ TEST_FUNC(tile_2d) {
});
// clang-format on
auto li = getForInductionVarOwner(i.getValue()),
lj = getForInductionVarOwner(j.getValue()),
lk1 = getForInductionVarOwner(k1.getValue()),
lk2 = getForInductionVarOwner(k2.getValue());
auto li = getForInductionVarOwner(i), lj = getForInductionVarOwner(j),
lk1 = getForInductionVarOwner(k1), lk2 = getForInductionVarOwner(k2);
auto indicesL1 = mlir::tile({li, lj}, {512, 1024}, {lk1, lk2});
auto lii1 = indicesL1[0][0], ljj1 = indicesL1[1][0];
mlir::tile({ljj1, lii1}, {32, 16}, ljj1);
@@ -713,15 +669,15 @@ TEST_FUNC(indirect_access) {
OpBuilder builder(f.getBody());
ScopedContext scope(builder, f.getLoc());
ValueHandle zero = std_constant_index(0);
Value zero = std_constant_index(0);
MemRefBoundsCapture vC(f.getArgument(2));
AffineIndexedValue B(f.getArgument(1)), D(f.getArgument(3));
StdIndexedValue A(f.getArgument(0)), C(f.getArgument(2));
ValueHandle i(builder.getIndexType()), N(vC.ub(0));
Value i, N(vC.ub(0));
// clang-format off
AffineLoopNestBuilder(&i, zero, N, 1)([&]{
C((ValueHandle)D(i)) = A((ValueHandle)B(i));
C((Value)D(i)) = A((Value)B(i));
});
// clang-format on
@@ -747,12 +703,12 @@ TEST_FUNC(empty_map_load_store) {
OpBuilder builder(f.getBody());
ScopedContext scope(builder, f.getLoc());
ValueHandle zero = std_constant_index(0);
ValueHandle one = std_constant_index(1);
Value zero = std_constant_index(0);
Value one = std_constant_index(1);
AffineIndexedValue input(f.getArgument(0)), res(f.getArgument(1));
ValueHandle iv(builder.getIndexType());
// clang-format off
Value iv;
AffineLoopNestBuilder(&iv, zero, one, 1)([&]{
res() = input();
});
@@ -784,7 +740,7 @@ TEST_FUNC(affine_if_op) {
OpBuilder builder(f.getBody());
ScopedContext scope(builder, f.getLoc());
ValueHandle zero = std_constant_index(0), ten = std_constant_index(10);
Value zero = std_constant_index(0), ten = std_constant_index(10);
SmallVector<bool, 4> isEq = {false, false, false, false};
SmallVector<AffineExpr, 4> affineExprs = {
@@ -834,7 +790,7 @@ TEST_FUNC(linalg_generic_pointwise_test) {
OpBuilder builder(f.getBody());
ScopedContext scope(builder, f.getLoc());
ValueHandle A(f.getArgument(0)), B(f.getArgument(1)), C(f.getArgument(2));
Value A(f.getArgument(0)), B(f.getArgument(1)), C(f.getArgument(2));
AffineExpr i, j;
bindDims(&globalContext(), i, j);
StructuredIndexed SA(A), SB(B), SC(C);
@@ -864,12 +820,12 @@ TEST_FUNC(linalg_generic_matmul_test) {
auto f32Type = FloatType::getF32(&globalContext());
auto memrefType = MemRefType::get(
{ShapedType::kDynamicSize, ShapedType::kDynamicSize}, f32Type, {}, 0);
auto f =
makeFunction("linalg_generic_matmul", {}, {memrefType, memrefType, memrefType});
auto f = makeFunction("linalg_generic_matmul", {},
{memrefType, memrefType, memrefType});
OpBuilder builder(f.getBody());
ScopedContext scope(builder, f.getLoc());
linalg_generic_matmul(makeValueHandles(llvm::to_vector<3>(f.getArguments())));
linalg_generic_matmul(f.getArguments());
f.print(llvm::outs());
f.erase();
@@ -902,8 +858,8 @@ TEST_FUNC(linalg_generic_conv_nhwc) {
OpBuilder builder(f.getBody());
ScopedContext scope(builder, f.getLoc());
linalg_generic_conv_nhwc(makeValueHandles(llvm::to_vector<3>(f.getArguments())),
/*strides=*/{3, 4}, /*dilations=*/{5, 6});
linalg_generic_conv_nhwc(f.getArguments(),
/*strides=*/{3, 4}, /*dilations=*/{5, 6});
f.print(llvm::outs());
f.erase();
@@ -936,9 +892,9 @@ TEST_FUNC(linalg_generic_dilated_conv_nhwc) {
OpBuilder builder(f.getBody());
ScopedContext scope(builder, f.getLoc());
linalg_generic_dilated_conv_nhwc(makeValueHandles(f.getArguments()),
/*depth_multiplier=*/7,
/*strides=*/{3, 4}, /*dilations=*/{5, 6});
linalg_generic_dilated_conv_nhwc(f.getArguments(),
/*depth_multiplier=*/7,
/*strides=*/{3, 4}, /*dilations=*/{5, 6});
f.print(llvm::outs());
f.erase();
@@ -958,7 +914,7 @@ TEST_FUNC(linalg_metadata_ops) {
ScopedContext scope(builder, f.getLoc());
AffineExpr i, j, k;
bindDims(&globalContext(), i, j, k);
ValueHandle v(f.getArgument(0));
Value v(f.getArgument(0));
auto reshaped = linalg_reshape(v, ArrayRef<ArrayRef<AffineExpr>>{{i, j}, k});
linalg_reshape(memrefType, reshaped,
ArrayRef<ArrayRef<AffineExpr>>{{i, j}, k});
@@ -1015,7 +971,7 @@ TEST_FUNC(linalg_tensors_test) {
OpBuilder builder(f.getBody());
ScopedContext scope(builder, f.getLoc());
ValueHandle A(f.getArgument(0)), B(f.getArgument(1));
Value A(f.getArgument(0)), B(f.getArgument(1));
AffineExpr i, j;
bindDims(&globalContext(), i, j);
StructuredIndexed SA(A), SB(B), SC(tensorType);
@@ -1023,7 +979,7 @@ TEST_FUNC(linalg_tensors_test) {
linalg_generic_pointwise_max(SA({i, j}), SB({i, j}), SC({i, j}));
linalg_generic_pointwise_tanh(SA({i, j}), SC({i, j}));
Value o1 = linalg_generic_matmul(A, B, tensorType)->getResult(0);
linalg_generic_matmul(A, B, ValueHandle(o1), tensorType);
linalg_generic_matmul(A, B, o1, tensorType);
f.print(llvm::outs());
f.erase();
@@ -1064,7 +1020,7 @@ TEST_FUNC(memref_vector_matmul_test) {
OpBuilder builder(f.getBody());
ScopedContext scope(builder, f.getLoc());
ValueHandle A(f.getArgument(0)), B(f.getArgument(1)), C(f.getArgument(2));
Value A(f.getArgument(0)), B(f.getArgument(1)), C(f.getArgument(2));
auto contractionBuilder = [](ArrayRef<BlockArgument> args) {
assert(args.size() == 3 && "expected 3 block arguments");
(linalg_yield(vector_contraction_matmul(args[0], args[1], args[2])));
@@ -1083,19 +1039,19 @@ TEST_FUNC(builder_loop_for_yield) {
OpBuilder builder(f.getBody());
ScopedContext scope(builder, f.getLoc());
ValueHandle init0 = std_constant_float(llvm::APFloat(1.0f), f32Type);
ValueHandle init1 = std_constant_float(llvm::APFloat(2.0f), f32Type);
ValueHandle i(indexType), a(f.getArgument(0)), b(f.getArgument(1)),
c(f.getArgument(2)), d(f.getArgument(3));
ValueHandle arg0(f32Type);
ValueHandle arg1(f32Type);
Value init0 = std_constant_float(llvm::APFloat(1.0f), f32Type);
Value init1 = std_constant_float(llvm::APFloat(2.0f), f32Type);
Value i, a(f.getArgument(0)), b(f.getArgument(1)), c(f.getArgument(2)),
d(f.getArgument(3));
Value args01[2];
Value &arg0 = args01[0], &arg1 = args01[1];
using namespace edsc::op;
auto results =
LoopNestBuilder(&i, a - b, c + d, a, {&arg0, &arg1}, {init0, init1})([&] {
LoopNestBuilder(&i, a - b, c + d, a, args01, {init0, init1})([&] {
auto sum = arg0 + arg1;
loop_yield(ArrayRef<ValueHandle>{arg1, sum});
loop_yield(ArrayRef<Value>{arg1, sum});
});
ValueHandle(results[0]) + ValueHandle(results[1]);
results[0] + results[1];
// clang-format off
// CHECK-LABEL: func @builder_loop_for_yield(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index) {

View File

@@ -16,9 +16,9 @@
// IMPL-NEXT: AffineMap::get(2, 0, {d0}, context) };
//
// IMPL: Test1Op::regionBuilder(Block &block) {
// IMPL: ValueHandle [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]);
// IMPL: ValueHandle [[d:.*]] = std_mulf([[a]], [[b]]);
// IMPL: ValueHandle [[e:.*]] = std_addf([[c]], [[d]]);
// IMPL: Value [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]);
// IMPL: Value [[d:.*]] = std_mulf([[a]], [[b]]);
// IMPL: Value [[e:.*]] = std_addf([[c]], [[d]]);
// IMPL: (linalg_yield(ValueRange{ [[e]] }));
//
ods_def<Test1Op> :
@@ -41,9 +41,9 @@ def test1(A: f32(M, K), B: f32(K)) -> (C: f32(M)) {
// IMPL-NEXT: AffineMap::get(3, 0, {d0, d1}, context) };
//
// IMPL: Test2Op::regionBuilder(Block &block) {
// IMPL: ValueHandle [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]);
// IMPL: ValueHandle [[d:.*]] = std_mulf([[a]], [[b]]);
// IMPL: ValueHandle [[e:.*]] = std_addf([[c]], [[d]]);
// IMPL: Value [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]);
// IMPL: Value [[d:.*]] = std_mulf([[a]], [[b]]);
// IMPL: Value [[e:.*]] = std_addf([[c]], [[d]]);
// IMPL: (linalg_yield(ValueRange{ [[e]] }));
//
ods_def<Test2Op> :
@@ -66,9 +66,9 @@ def test2(A: f32(M, K), B: f32(K, N)) -> (C: f32(M, N)) {
// IMPL-NEXT: AffineMap::get(4, 0, {d0, d1, d2}, context) };
//
// IMPL: Test3Op::regionBuilder(Block &block) {
// IMPL: ValueHandle [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]);
// IMPL: ValueHandle [[d:.*]] = std_mulf([[a]], [[b]]);
// IMPL: ValueHandle [[e:.*]] = std_addf([[c]], [[d]]);
// IMPL: Value [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]);
// IMPL: Value [[d:.*]] = std_mulf([[a]], [[b]]);
// IMPL: Value [[e:.*]] = std_addf([[c]], [[d]]);
// IMPL: (linalg_yield(ValueRange{ [[e]] }));
//
ods_def<Test3Op> :

View File

@@ -1601,7 +1601,7 @@ void TCParser::printRegionBuilder(llvm::raw_ostream &os, StringRef cppOpName,
printExpr(subExprsStringStream, *e);
});
subExprsStringStream.flush();
const char *tensorExprFmt = "\n ValueHandle _{0} = {1}({2});";
const char *tensorExprFmt = "\n Value _{0} = {1}({2});";
os << llvm::formatv(tensorExprFmt, ++count, pTensorExpr->operationName,
subExprs);
subExprsMap[pTensorExpr] = count;
@@ -1613,7 +1613,7 @@ void TCParser::printRegionBuilder(llvm::raw_ostream &os, StringRef cppOpName,
using namespace edsc;
using namespace intrinsics;
auto args = block.getArguments();
ValueHandle {1};
Value {1};
{2}
(linalg_yield(ValueRange{ {3} }));
})FMT";