diff --git a/mlir/include/mlir/EDSC/Builders.h b/mlir/include/mlir/EDSC/Builders.h
index c634505d3474..10dc81bcd08b 100644
--- a/mlir/include/mlir/EDSC/Builders.h
+++ b/mlir/include/mlir/EDSC/Builders.h
@@ -26,6 +26,7 @@
 #include "mlir/AffineOps/AffineOps.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/StandardOps/Ops.h"
+#include "mlir/SuperVectorOps/SuperVectorOps.h"

 namespace mlir {

@@ -291,6 +292,14 @@ public:
   /// assigning.
   ValueHandle &operator=(const ValueHandle &other);

+  /// Provide a swap operator.
+  void swap(ValueHandle &other) {
+    if (this == &other)
+      return;
+    std::swap(t, other.t);
+    std::swap(v, other.v);
+  }
+
   /// Implicit conversion useful for automatic conversion to Container.
   operator Value *() const { return getValue(); }

@@ -310,11 +319,14 @@ public:
                            ArrayRef<NamedAttribute> attributes = {});

   bool hasValue() const { return v != nullptr; }
-  Value *getValue() const { return v; }
+  Value *getValue() const {
+    assert(hasValue() && "Unexpected null value");
+    return v;
+  }
   bool hasType() const { return t != Type(); }
   Type getType() const { return t; }

-private:
+protected:
   ValueHandle() : t(), v(nullptr) {}

   Type t;
@@ -442,6 +454,12 @@ ValueHandle operator!(ValueHandle value);
 ValueHandle operator&&(ValueHandle lhs, ValueHandle rhs);
 ValueHandle operator||(ValueHandle lhs, ValueHandle rhs);
 ValueHandle operator^(ValueHandle lhs, ValueHandle rhs);
+ValueHandle operator==(ValueHandle lhs, ValueHandle rhs);
+ValueHandle operator!=(ValueHandle lhs, ValueHandle rhs);
+ValueHandle operator<(ValueHandle lhs, ValueHandle rhs);
+ValueHandle operator<=(ValueHandle lhs, ValueHandle rhs);
+ValueHandle operator>(ValueHandle lhs, ValueHandle rhs);
+ValueHandle operator>=(ValueHandle lhs, ValueHandle rhs);

 } // namespace op
 } // namespace edsc
diff --git a/mlir/include/mlir/EDSC/Helpers.h b/mlir/include/mlir/EDSC/Helpers.h
index d653c7f4d359..660b957b5068 100644
--- a/mlir/include/mlir/EDSC/Helpers.h
+++ b/mlir/include/mlir/EDSC/Helpers.h
@@ -25,6 +25,8 @@
 #include "mlir/EDSC/Builders.h"
 #include "mlir/EDSC/Intrinsics.h"

+#include "llvm/Support/raw_ostream.h"
+
 namespace mlir {
 namespace edsc {

@@ -45,7 +47,48 @@ struct IndexHandle : public ValueHandle {
     assert(v->getType() == ScopedContext::getBuilder()->getIndexType() &&
            "Expected index type");
   }
-  explicit IndexHandle(ValueHandle v) : ValueHandle(v) {}
+  explicit IndexHandle(ValueHandle v) : ValueHandle(v) {
+    assert(v.getType() == ScopedContext::getBuilder()->getIndexType() &&
+           "Expected index type");
+  }
+  IndexHandle &operator=(const ValueHandle &v) {
+    assert(v.getType() == ScopedContext::getBuilder()->getIndexType() &&
+           "Expected index type");
+    /// Creating a new IndexHandle(v) and then std::swap rightly complains the
+    /// binding has already occurred and that we should use another name.
+    this->t = v.getType();
+    this->v = v.getValue();
+    return *this;
+  }
+};
+
+// Base class for MemRefView and VectorView.
+class View { +public: + unsigned rank() const { return lbs.size(); } + ValueHandle lb(unsigned idx) { return lbs[idx]; } + ValueHandle ub(unsigned idx) { return ubs[idx]; } + int64_t step(unsigned idx) { return steps[idx]; } + std::tuple range(unsigned idx) { + return std::make_tuple(lbs[idx], ubs[idx], steps[idx]); + } + void swapRanges(unsigned i, unsigned j) { + llvm::errs() << "\nSWAP: " << i << " " << j; + if (i == j) + return; + lbs[i].swap(lbs[j]); + ubs[i].swap(ubs[j]); + std::swap(steps[i], steps[j]); + } + + ArrayRef getLbs() { return lbs; } + ArrayRef getUbs() { return ubs; } + ArrayRef getSteps() { return steps; } + +protected: + SmallVector lbs; + SmallVector ubs; + SmallVector steps; }; /// A MemRefView represents the information required to step through a @@ -53,35 +96,32 @@ struct IndexHandle : public ValueHandle { /// Fortran subarray model. /// At the moment it can only capture a MemRef with an identity layout map. // TODO(ntv): Support MemRefs with layoutMaps. -class MemRefView { +class MemRefView : public View { public: explicit MemRefView(Value *v); MemRefView(const MemRefView &) = default; MemRefView &operator=(const MemRefView &) = default; - unsigned rank() const { return lbs.size(); } unsigned fastestVarying() const { return rank() - 1; } - IndexHandle lb(unsigned idx) { return lbs[idx]; } - IndexHandle ub(unsigned idx) { return ubs[idx]; } - int64_t step(unsigned idx) { return steps[idx]; } - std::tuple range(unsigned idx) { - return std::make_tuple(lbs[idx], ubs[idx], steps[idx]); - } - private: friend IndexedValue; - ValueHandle base; - SmallVector lbs; - SmallVector ubs; - SmallVector steps; }; -ValueHandle operator+(ValueHandle v, IndexedValue i); -ValueHandle operator-(ValueHandle v, IndexedValue i); -ValueHandle operator*(ValueHandle v, IndexedValue i); -ValueHandle operator/(ValueHandle v, IndexedValue i); +/// A VectorView represents the information required to step through a +/// Vector accessing each scalar element at a time. It is the counterpart of +/// a MemRefView but for vectors. This exists purely for boilerplate avoidance. +class VectorView : public View { +public: + explicit VectorView(Value *v); + VectorView(const VectorView &) = default; + VectorView &operator=(const VectorView &) = default; + +private: + friend IndexedValue; + ValueHandle base; +}; /// This helper class is an abstraction over memref, that purely for sugaring /// purposes and allows writing compact expressions such as: @@ -97,32 +137,54 @@ ValueHandle operator/(ValueHandle v, IndexedValue i); /// converting an IndexedValue to a ValueHandle emits an actual load operation. struct IndexedValue { explicit IndexedValue(Type t) : base(t) {} - explicit IndexedValue(Value *v, llvm::ArrayRef indices = {}) - : IndexedValue(ValueHandle(v), indices) {} - explicit IndexedValue(ValueHandle v, llvm::ArrayRef indices = {}) - : base(v), indices(indices.begin(), indices.end()) {} + explicit IndexedValue(Value *v) : IndexedValue(ValueHandle(v)) {} + explicit IndexedValue(ValueHandle v) : base(v) {} IndexedValue(const IndexedValue &rhs) = default; - IndexedValue &operator=(const IndexedValue &rhs) = default; + ValueHandle operator()() { return ValueHandle(*this); } /// Returns a new `IndexedValue`. - IndexedValue operator()(llvm::ArrayRef indices = {}) { + IndexedValue operator()(ValueHandle index) { + IndexedValue res(base); + res.indices.push_back(index); + return res; + } + template + IndexedValue operator()(ValueHandle index, Args... 
indices) { + return IndexedValue(base, index).append(indices...); + } + IndexedValue operator()(llvm::ArrayRef indices) { return IndexedValue(base, indices); } + IndexedValue operator()(llvm::ArrayRef indices) { + return IndexedValue( + base, llvm::ArrayRef(indices.begin(), indices.end())); + } /// Emits a `store`. // NOLINTNEXTLINE: unconventional-assign-operator + InstructionHandle operator=(const IndexedValue &rhs) { + ValueHandle rrhs(rhs); + assert(getBase().getType().cast().getRank() == indices.size() && + "Unexpected number of indices to store in MemRef"); + return intrinsics::STORE(rrhs, getBase(), indices); + } + // NOLINTNEXTLINE: unconventional-assign-operator InstructionHandle operator=(ValueHandle rhs) { + assert(getBase().getType().cast().getRank() == indices.size() && + "Unexpected number of indices to store in MemRef"); return intrinsics::STORE(rhs, getBase(), indices); } - ValueHandle getBase() const { return base; } - /// Emits a `load` when converting to a ValueHandle. - explicit operator ValueHandle() { + operator ValueHandle() const { + assert(getBase().getType().cast().getRank() == indices.size() && + "Unexpected number of indices to store in MemRef"); return intrinsics::LOAD(getBase(), indices); } + ValueHandle getBase() const { return base; } + /// Operator overloadings. ValueHandle operator+(ValueHandle e); ValueHandle operator-(ValueHandle e); @@ -158,6 +220,17 @@ struct IndexedValue { } private: + IndexedValue(ValueHandle base, ArrayRef indices) + : base(base), indices(indices.begin(), indices.end()) {} + + IndexedValue &append() { return *this; } + + template + IndexedValue &append(T index, Args... indices) { + this->indices.push_back(static_cast(index)); + append(indices...); + return *this; + } ValueHandle base; llvm::SmallVector indices; }; diff --git a/mlir/include/mlir/EDSC/Intrinsics.h b/mlir/include/mlir/EDSC/Intrinsics.h index cbfe43efc422..596119fca508 100644 --- a/mlir/include/mlir/EDSC/Intrinsics.h +++ b/mlir/include/mlir/EDSC/Intrinsics.h @@ -23,10 +23,14 @@ #ifndef MLIR_EDSC_INTRINSICS_H_ #define MLIR_EDSC_INTRINSICS_H_ +#include "mlir/EDSC/Builders.h" #include "mlir/Support/LLVM.h" namespace mlir { +class MemRefType; +class Type; + namespace edsc { class BlockHandle; @@ -94,9 +98,34 @@ InstructionHandle COND_BR(ValueHandle cond, BlockHandle *trueBranch, ArrayRef falseCaptures, ArrayRef falseOperands); -//////////////////////////////////////////////////////////////////////////////// -// TODO(ntv): Intrinsics below this line should be TableGen'd. -//////////////////////////////////////////////////////////////////////////////// +/// Helper variadic abstraction to allow extending to any MLIR op without +/// boilerplate or Tablegen. +/// Arguably a builder is not a ValueHandle but in practice it is only used as +/// an alias to a notional ValueHandle. +/// Implementing it as a subclass allows it to compose all the way to Value*. +/// Without subclassing, implicit conversion to Value* would fail when composing +/// in patterns such as: `select(a, b, select(c, d, e))`. +template struct EDSCValueBuilder : public ValueHandle { + template + EDSCValueBuilder(Args... args) + : ValueHandle(ValueHandle::create(std::forward(args)...)) {} + EDSCValueBuilder() = delete; +}; + +template +struct EDSCInstructionBuilder : public InstructionHandle { + template + EDSCInstructionBuilder(Args... 
args) + : InstructionHandle( + InstructionHandle::create(std::forward(args)...)) {} + EDSCInstructionBuilder() = delete; +}; + +using alloc = EDSCValueBuilder; +using dealloc = EDSCInstructionBuilder; +using select = EDSCValueBuilder; +using vector_type_cast = EDSCValueBuilder; + /// Builds an mlir::LoadOp with the proper `operands` that each must have /// captured an mlir::Value*. /// Returns a ValueHandle to the produced mlir::Value*. @@ -104,19 +133,16 @@ ValueHandle LOAD(ValueHandle base, llvm::ArrayRef indices); /// Builds an mlir::ReturnOp with the proper `operands` that each must have /// captured an mlir::Value*. -/// Returns an empty ValueHandle. +/// Returns an InstructionHandle. InstructionHandle RETURN(llvm::ArrayRef operands); /// Builds an mlir::StoreOp with the proper `operands` that each must have /// captured an mlir::Value*. -/// Returns an empty ValueHandle. +/// Returns an InstructionHandle. InstructionHandle STORE(ValueHandle value, ValueHandle base, llvm::ArrayRef indices); - } // namespace intrinsics - } // namespace edsc - } // namespace mlir #endif // MLIR_EDSC_INTRINSICS_H_ diff --git a/mlir/lib/EDSC/Builders.cpp b/mlir/lib/EDSC/Builders.cpp index 82399af26b5f..b3cfad61cc0e 100644 --- a/mlir/lib/EDSC/Builders.cpp +++ b/mlir/lib/EDSC/Builders.cpp @@ -381,3 +381,36 @@ ValueHandle mlir::edsc::op::operator&&(ValueHandle lhs, ValueHandle rhs) { ValueHandle mlir::edsc::op::operator||(ValueHandle lhs, ValueHandle rhs) { return !(!lhs && !rhs); } + +static ValueHandle createComparisonExpr(CmpIPredicate predicate, + ValueHandle lhs, ValueHandle rhs) { + auto lhsType = lhs.getType(); + auto rhsType = rhs.getType(); + assert(lhsType == rhsType && "cannot mix types in operators"); + assert((lhsType.isa() || lhsType.isa()) && + "only integer comparisons are supported"); + + auto op = ScopedContext::getBuilder()->create( + ScopedContext::getLocation(), predicate, lhs.getValue(), rhs.getValue()); + return ValueHandle(op->getResult()); +} + +ValueHandle mlir::edsc::op::operator==(ValueHandle lhs, ValueHandle rhs) { + return createComparisonExpr(CmpIPredicate::EQ, lhs, rhs); +} +ValueHandle mlir::edsc::op::operator!=(ValueHandle lhs, ValueHandle rhs) { + return createComparisonExpr(CmpIPredicate::NE, lhs, rhs); +} +ValueHandle mlir::edsc::op::operator<(ValueHandle lhs, ValueHandle rhs) { + // TODO(ntv,zinenko): signed by default, how about unsigned? 
+ return createComparisonExpr(CmpIPredicate::SLT, lhs, rhs); +} +ValueHandle mlir::edsc::op::operator<=(ValueHandle lhs, ValueHandle rhs) { + return createComparisonExpr(CmpIPredicate::SLE, lhs, rhs); +} +ValueHandle mlir::edsc::op::operator>(ValueHandle lhs, ValueHandle rhs) { + return createComparisonExpr(CmpIPredicate::SGT, lhs, rhs); +} +ValueHandle mlir::edsc::op::operator>=(ValueHandle lhs, ValueHandle rhs) { + return createComparisonExpr(CmpIPredicate::SGE, lhs, rhs); +} diff --git a/mlir/lib/EDSC/Helpers.cpp b/mlir/lib/EDSC/Helpers.cpp index 6400c73e2295..72eaf00634f6 100644 --- a/mlir/lib/EDSC/Helpers.cpp +++ b/mlir/lib/EDSC/Helpers.cpp @@ -22,21 +22,21 @@ using namespace mlir; using namespace mlir::edsc; -static SmallVector getMemRefSizes(Value *memRef) { +static SmallVector getMemRefSizes(Value *memRef) { MemRefType memRefType = memRef->getType().cast(); auto maps = memRefType.getAffineMaps(); (void)maps; assert((maps.empty() || (maps.size() == 1 && maps[0].isIdentity())) && "Layout maps not supported"); - SmallVector res; + SmallVector res; res.reserve(memRefType.getShape().size()); const auto &shape = memRefType.getShape(); for (unsigned idx = 0, n = shape.size(); idx < n; ++idx) { if (shape[idx] == -1) { - res.push_back(IndexHandle(ValueHandle::create(memRef, idx))); + res.push_back(ValueHandle::create(memRef, idx)); } else { - res.push_back(IndexHandle(static_cast(shape[idx]))); + res.push_back(static_cast(shape[idx])); } } return res; @@ -47,12 +47,22 @@ mlir::edsc::MemRefView::MemRefView(Value *v) : base(v) { auto memrefSizeValues = getMemRefSizes(v); for (auto &size : memrefSizeValues) { - lbs.push_back(IndexHandle(static_cast(0))); + lbs.push_back(static_cast(0)); ubs.push_back(size); steps.push_back(1); } } +mlir::edsc::VectorView::VectorView(Value *v) : base(v) { + auto vectorType = v->getType().cast(); + + for (auto s : vectorType.getShape()) { + lbs.push_back(static_cast(0)); + ubs.push_back(static_cast(s)); + steps.push_back(1); + } +} + /// Operator overloadings. 
ValueHandle mlir::edsc::IndexedValue::operator+(ValueHandle e) { using op::operator+; @@ -87,23 +97,3 @@ InstructionHandle mlir::edsc::IndexedValue::operator/=(ValueHandle e) { using op::operator/; return intrinsics::STORE(*this / e, getBase(), indices); } - -ValueHandle mlir::edsc::operator+(ValueHandle v, IndexedValue i) { - using op::operator+; - return v + static_cast(i); -} - -ValueHandle mlir::edsc::operator-(ValueHandle v, IndexedValue i) { - using op::operator-; - return v - static_cast(i); -} - -ValueHandle mlir::edsc::operator*(ValueHandle v, IndexedValue i) { - using op::operator*; - return v * static_cast(i); -} - -ValueHandle mlir::edsc::operator/(ValueHandle v, IndexedValue i) { - using op::operator/; - return v / static_cast(i); -} diff --git a/mlir/lib/EDSC/Intrinsics.cpp b/mlir/lib/EDSC/Intrinsics.cpp index 2bfb8a7fc026..887561437ebe 100644 --- a/mlir/lib/EDSC/Intrinsics.cpp +++ b/mlir/lib/EDSC/Intrinsics.cpp @@ -18,6 +18,7 @@ #include "mlir/EDSC/Intrinsics.h" #include "mlir/EDSC/Builders.h" #include "mlir/IR/AffineExpr.h" +#include "mlir/SuperVectorOps/SuperVectorOps.h" using namespace mlir; using namespace mlir::edsc; diff --git a/mlir/lib/Transforms/LowerVectorTransfers.cpp b/mlir/lib/Transforms/LowerVectorTransfers.cpp index 9ac8583bc78c..395489551a8e 100644 --- a/mlir/lib/Transforms/LowerVectorTransfers.cpp +++ b/mlir/lib/Transforms/LowerVectorTransfers.cpp @@ -25,7 +25,8 @@ #include "mlir/Analysis/NestedMatcher.h" #include "mlir/Analysis/Utils.h" #include "mlir/Analysis/VectorAnalysis.h" -#include "mlir/EDSC/MLIREmitter.h" +#include "mlir/EDSC/Builders.h" +#include "mlir/EDSC/Helpers.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Attributes.h" @@ -43,82 +44,14 @@ #include "mlir/Transforms/Passes.h" #include "llvm/ADT/SetVector.h" -#include "llvm/Support/Allocator.h" -#include "llvm/Support/Debug.h" -#include "llvm/Support/raw_ostream.h" /// /// Implements lowering of VectorTransferReadOp and VectorTransferWriteOp to a /// proper abstraction for the hardware. /// -/// For now only a simple loop nest is emitted. +/// For now, we only emit a simple loop nest that performs clipped pointwise +/// copies from a remote to a locally allocated memory. /// - -using llvm::dbgs; -using llvm::SetVector; - -using namespace mlir; - -#define DEBUG_TYPE "lower-vector-transfers" - -namespace { -/// Helper structure to hold information about loop nest, clipped accesses to -/// the original scalar MemRef as well as full accesses to temporary MemRef in -/// local storage. -struct VectorTransferAccessInfo { - // `ivs` are bound for `For` Stmt at `For` Stmt construction time. - llvm::SmallVector ivs; - llvm::SmallVector lowerBoundsExprs; - llvm::SmallVector upperBoundsExprs; - llvm::SmallVector stepExprs; - llvm::SmallVector clippedScalarAccessExprs; - llvm::SmallVector tmpAccessExprs; -}; - -template class VectorTransferRewriter { -public: - /// Perform the rewrite using the `emitter`. - VectorTransferRewriter(VectorTransferOpTy *transfer, - MLFuncLoweringRewriter *rewriter, - MLFuncGlobalLoweringState *state); - - /// Perform the rewrite using the `emitter`. - void rewrite(); - - /// Helper class which creates clipped memref accesses to support lowering of - /// the vector_transfer operation. 
- VectorTransferAccessInfo makeVectorTransferAccessInfo(); - -private: - VectorTransferOpTy *transfer; - MLFuncLoweringRewriter *rewriter; - MLFuncGlobalLoweringState *state; - - MemRefType memrefType; - ArrayRef memrefShape; - VectorType vectorType; - ArrayRef vectorShape; - AffineMap permutationMap; - - /// Used for staging the transfer in a local scalar buffer. - MemRefType tmpMemRefType; - /// View of tmpMemRefType as one vector, used in vector load/store to tmp - /// buffer. - MemRefType vectorMemRefType; - - // EDSC `emitter` and Exprs that are pre-bound at construction time. - edsc::MLIREmitter emitter; - // vectorSizes are bound to the actual constant sizes of vectorType. - llvm::SmallVector vectorSizes; - // accesses are bound to transfer->getIndices() - llvm::SmallVector accesses; - // `zero` and `one` are bound emitter.zero() and emitter.one(). - // `scalarMemRef` is bound to `transfer->getMemRef()`. - edsc::Expr zero, one, scalarMemRef; -}; - -} // end anonymous namespace - /// Consider the case: /// /// ```mlir {.mlir} @@ -134,141 +67,160 @@ private: /// }}} /// ``` /// -/// The following constructs the `loadAccessExpr` that supports the emission of -/// MLIR resembling: +/// The rewriters construct loop and indices that access MemRef A in a pattern +/// resembling the following (while guaranteeing an always full-tile +/// abstraction): /// -/// ```mlir -/// for %d1 = 0 to 256 { -/// for %d2 = 0 to 32 { +/// ```mlir {.mlir} +/// for %d2 = 0 to 256 { +/// for %d1 = 0 to 32 { /// %s = %A[%i0, %i1 + %d1, %i2 + %d2] : f32 /// %tmp[%d2, %d1] = %s /// } /// } /// ``` /// -/// Notice in particular the order of loops iterating over the vector size -/// (i.e. 256x32 instead of 32x256). This results in contiguous accesses along -/// the most minor dimension of the original scalar tensor. On many hardware -/// architectures this will result in better utilization of the underlying -/// memory subsystem (e.g. prefetchers, DMAs, #memory transactions, etc...). -/// -/// This additionally performs clipping as described in -/// `VectorTransferRewriter::rewrite` by emitting: +/// In the current state, only a clipping transfer is implemented by `clip`, +/// which creates individual indexing expressions of the form: /// /// ```mlir-dsc -/// select(i + ii < zero, zero, select(i + ii < N, i + ii, N - one)) +/// SELECT(i + ii < zero, zero, SELECT(i + ii < N, i + ii, N - one)) /// ``` -template -VectorTransferAccessInfo -VectorTransferRewriter::makeVectorTransferAccessInfo() { - using namespace mlir::edsc; - using namespace edsc::op; - // Create new Exprs for ivs, they will be bound at `For` Stmt - // construction. - auto ivs = makeNewExprs(vectorShape.size(), this->rewriter->getIndexType()); +using namespace mlir; - // Create and bind Exprs to refer to the Value for memref sizes. - auto memRefSizes = emitter.makeBoundMemRefShape(transfer->getMemRef()); +#define DEBUG_TYPE "lower-vector-transfers" - // Create the edsc::Expr for the clipped and transposes access expressions - // using the permutationMap. Additionally, capture the index accessing the - // most minor dimension. 
- int coalescingIndex = -1; - auto clippedScalarAccessExprs = copyExprs(accesses); - auto tmpAccessExprs = copyExprs(ivs); - llvm::DenseSet clipped; - for (auto it : llvm::enumerate(permutationMap.getResults())) { - if (auto affineExpr = it.value().template dyn_cast()) { - auto pos = affineExpr.getPosition(); - auto i = clippedScalarAccessExprs[pos]; - auto ii = ivs[it.index()]; - auto N = memRefSizes[pos]; - clippedScalarAccessExprs[pos] = - select(i + ii < zero, zero, select(i + ii < N, i + ii, N - one)); - if (pos == clippedScalarAccessExprs.size() - 1) { - // If a result of the permutation_map accesses the most minor dimension - // then we record it. - coalescingIndex = it.index(); - } - // Temporarily record already clipped accesses to avoid double clipping. - // TODO(ntv): remove when fully unrolled dimensions are clipped properly. - clipped.insert(pos); - } else { - // Sanity check. - assert(it.value().template cast().getValue() == 0 && - "Expected dim or 0 in permutationMap"); - } +namespace { + +/// Lowers VectorTransferOp into a combination of: +/// 1. local memory allocation; +/// 2. perfect loop nest over: +/// a. scalar load/stores from local buffers (viewed as a scalar memref); +/// a. scalar store/load to original memref (with clipping). +/// 3. vector_load/store +/// 4. local memory deallocation. +/// Minor variations occur depending on whether a VectorTransferReadOp or +/// a VectorTransferWriteOp is rewritten. +template class VectorTransferRewriter { +public: + VectorTransferRewriter(VectorTransferOpTy *transfer, + MLFuncLoweringRewriter *rewriter, + MLFuncGlobalLoweringState *state); + + /// Used for staging the transfer in a local scalar buffer. + MemRefType tmpMemRefType() { + auto vectorType = transfer->getVectorType(); + return MemRefType::get(vectorType.getShape(), vectorType.getElementType(), + {}, 0); } + /// View of tmpMemRefType as one vector, used in vector load/store to tmp + /// buffer. + MemRefType vectorMemRefType() { + return MemRefType::get({1}, transfer->getVectorType(), {}, 0); + } + /// Performs the rewrite. + void rewrite(); - // At this point, fully unrolled dimensions have not been clipped because they - // do not appear in the permutation map. As a consequence they may access out - // of bounds. We currently do not have enough information to determine which - // of those access dimensions have been fully unrolled. - // Clip one more time to ensure correctness for fully-unrolled dimensions. - // TODO(ntv): clip just what is needed once we pass the proper information. - // TODO(ntv): when we get there, also ensure we only clip when dimensions are - // not divisible (i.e. simple test that can be hoisted outside loop). - for (unsigned pos = 0; pos < clippedScalarAccessExprs.size(); ++pos) { - if (clipped.count(pos) > 0) { +private: + VectorTransferOpTy *transfer; + MLFuncLoweringRewriter *rewriter; + MLFuncGlobalLoweringState *state; +}; +} // end anonymous namespace + +/// Analyzes the `transfer` to find an access dimension along the fastest remote +/// MemRef dimension. If such a dimension with coalescing properties is found, +/// `pivs` and `vectorView` are swapped so that the invocation of +/// LoopNestBuilder captures it in the innermost loop. +template +void coalesceCopy(VectorTransferOpTy *transfer, + SmallVectorImpl *pivs, + edsc::VectorView *vectorView) { + // rank of the remote memory access, coalescing behavior occurs on the + // innermost memory dimension. 
+ auto remoteRank = transfer->getMemRefType().getRank(); + // Iterate over the results expressions of the permutation map to determine + // the loop order for creating pointwise copies between remote and local + // memories. + int coalescedIdx = -1; + auto exprs = transfer->getPermutationMap().getResults(); + for (auto en : llvm::enumerate(exprs)) { + auto dim = en.value().template dyn_cast(); + if (!dim) { continue; } - auto i = clippedScalarAccessExprs[pos]; - auto N = memRefSizes[pos]; - clippedScalarAccessExprs[pos] = - select(i < zero, zero, select(i < N, i, N - one)); + auto memRefDim = dim.getPosition(); + if (memRefDim == remoteRank - 1) { + // memRefDim has coalescing properties, it should be swapped in the last + // position. + assert(coalescedIdx == -1 && "Unexpected > 1 coalesced indices"); + coalescedIdx = en.index(); + } + } + if (coalescedIdx >= 0) { + std::swap(pivs->back(), (*pivs)[coalescedIdx]); + vectorView->swapRanges(pivs->size() - 1, coalescedIdx); + } +} + +/// Emits remote memory accesses that are clipped to the boundaries of the +/// MemRef. +template +static llvm::SmallVector +clip(VectorTransferOpTy *transfer, edsc::MemRefView &view, + ArrayRef ivs) { + using namespace mlir::edsc; + using namespace edsc::op; + using edsc::intrinsics::select; + + IndexHandle zero(index_t(0)), one(index_t(1)); + llvm::SmallVector memRefAccess(transfer->getIndices()); + llvm::SmallVector clippedScalarAccessExprs( + memRefAccess.size(), edsc::IndexHandle()); + + // Indices accessing to remote memory are clipped and their expressions are + // returned in clippedScalarAccessExprs. + for (unsigned memRefDim = 0; memRefDim < clippedScalarAccessExprs.size(); + ++memRefDim) { + // Linear search on a small number of entries. + int loopIndex = -1; + auto exprs = transfer->getPermutationMap().getResults(); + for (auto en : llvm::enumerate(exprs)) { + auto expr = en.value(); + auto dim = expr.template dyn_cast(); + // Sanity check. + assert(dim || expr.template cast().getValue() == 0 && + "Expected dim or 0 in permutationMap"); + if (dim && memRefDim == dim.getPosition()) { + loopIndex = en.index(); + break; + } + } + + // We cannot distinguish atm between unrolled dimensions that implement + // the "always full" tile abstraction and need clipping from the other + // ones. So we conservatively clip everything. + auto N = view.ub(memRefDim); + auto i = memRefAccess[memRefDim]; + if (loopIndex < 0) { + clippedScalarAccessExprs[memRefDim] = + select(i < zero, zero, select(i < N, i, N - one)); + } else { + auto ii = ivs[loopIndex]; + clippedScalarAccessExprs[memRefDim] = + select(i + ii < zero, zero, select(i + ii < N, i + ii, N - one)); + } } - // Create the proper bindables for lbs, ubs and steps. Additionally, if we - // recorded a coalescing index, permute the loop informations. 
- auto lbs = makeNewExprs(ivs.size(), this->rewriter->getIndexType()); - auto ubs = copyExprs(vectorSizes); - auto steps = makeNewExprs(ivs.size(), this->rewriter->getIndexType()); - if (coalescingIndex >= 0) { - std::swap(ivs[coalescingIndex], ivs.back()); - std::swap(lbs[coalescingIndex], lbs.back()); - std::swap(ubs[coalescingIndex], ubs.back()); - std::swap(steps[coalescingIndex], steps.back()); - } - emitter - .template bindZipRangeConstants( - llvm::zip(lbs, SmallVector(ivs.size(), 0))) - .template bindZipRangeConstants( - llvm::zip(steps, SmallVector(ivs.size(), 1))); - - return VectorTransferAccessInfo{ivs, - copyExprs(lbs), - ubs, - copyExprs(steps), - clippedScalarAccessExprs, - tmpAccessExprs}; + return clippedScalarAccessExprs; } template VectorTransferRewriter::VectorTransferRewriter( VectorTransferOpTy *transfer, MLFuncLoweringRewriter *rewriter, MLFuncGlobalLoweringState *state) - : transfer(transfer), rewriter(rewriter), state(state), - memrefType(transfer->getMemRefType()), memrefShape(memrefType.getShape()), - vectorType(transfer->getVectorType()), vectorShape(vectorType.getShape()), - permutationMap(transfer->getPermutationMap()), - tmpMemRefType( - MemRefType::get(vectorShape, vectorType.getElementType(), {}, 0)), - vectorMemRefType(MemRefType::get({1}, vectorType, {}, 0)), - emitter(edsc::MLIREmitter(rewriter->getBuilder(), transfer->getLoc())), - vectorSizes( - edsc::makeNewExprs(vectorShape.size(), rewriter->getIndexType())), - zero(emitter.zero()), one(emitter.one()), - scalarMemRef(transfer->getMemRefType()) { - // Bind the Bindable. - SmallVector transferIndices(transfer->getIndices()); - accesses = edsc::makeNewExprs(transferIndices.size(), - this->rewriter->getIndexType()); - emitter.bind(edsc::Bindable(scalarMemRef), transfer->getMemRef()) - .template bindZipRangeConstants( - llvm::zip(vectorSizes, vectorShape)) - .template bindZipRange(llvm::zip(accesses, transfer->getIndices())); -}; + : transfer(transfer), rewriter(rewriter), state(state){}; /// Lowers VectorTransferReadOp into a combination of: /// 1. local memory allocation; @@ -313,38 +265,39 @@ VectorTransferRewriter::VectorTransferRewriter( /// TODO(ntv): support non-data-parallel operations. template <> void VectorTransferRewriter::rewrite() { using namespace mlir::edsc; + using namespace mlir::edsc::op; + using namespace mlir::edsc::intrinsics; - // Build the AccessInfo which contain all the information needed to build the - // perfectly nest loop nest to perform clipped reads and local writes. - auto accessInfo = makeVectorTransferAccessInfo(); + // 1. Setup all the captures. 
+ ScopedContext scope(FuncBuilder(transfer->getInstruction()), + transfer->getLoc()); + IndexedValue remote(transfer->getMemRef()); + MemRefView view(transfer->getMemRef()); + VectorView vectorView(transfer->getVector()); + SmallVector ivs(vectorView.rank()); + SmallVector pivs; + for (auto &idx : ivs) { + pivs.push_back(&idx); + } + coalesceCopy(transfer, &pivs, &vectorView); - // clang-format off - auto &ivs = accessInfo.ivs; - auto &lbs = accessInfo.lowerBoundsExprs; - auto &ubs = accessInfo.upperBoundsExprs; - auto &steps = accessInfo.stepExprs; + auto lbs = vectorView.getLbs(); + auto ubs = vectorView.getUbs(); + auto steps = vectorView.getSteps(); - auto vectorType = this->transfer->getVectorType(); - auto scalarType = this->transfer->getMemRefType().getElementType(); - - Expr scalarValue(scalarType), vectorValue(vectorType), tmpAlloc(tmpMemRefType), tmpDealloc(Type{}), vectorView(vectorMemRefType); - auto block = edsc::block({ - tmpAlloc = alloc(tmpMemRefType), - vectorView = vector_type_cast(Expr(tmpAlloc), vectorMemRefType), - For(ivs, lbs, ubs, steps, { - scalarValue = load(scalarMemRef, accessInfo.clippedScalarAccessExprs), - store(scalarValue, tmpAlloc, accessInfo.tmpAccessExprs), - }), - vectorValue = load(vectorView, {zero}), - tmpDealloc = dealloc(tmpAlloc) + // 2. Emit alloc-copy-load-dealloc. + ValueHandle tmp = alloc(tmpMemRefType()); + IndexedValue local(tmp); + ValueHandle vec = vector_type_cast(tmp, vectorMemRefType()); + LoopNestBuilder(pivs, lbs, ubs, steps)({ + // Computes clippedScalarAccessExprs in the loop nest scope (ivs exist). + local(ivs) = remote(clip(transfer, view, ivs)), }); - // clang-format on + ValueHandle vectorValue = LOAD(vec, {index_t(0)}); + (dealloc(tmp)); // vexing parse - // Emit the MLIR. - emitter.emitStmts(block.getBody()); - - // Finalize rewriting. - transfer->replaceAllUsesWith(emitter.getValue(vectorValue)); + // 3. Propagate. + transfer->replaceAllUsesWith(vectorValue.getValue()); transfer->erase(); } @@ -368,37 +321,38 @@ template <> void VectorTransferRewriter::rewrite() { /// TODO(ntv): support non-data-parallel operations. template <> void VectorTransferRewriter::rewrite() { using namespace mlir::edsc; + using namespace mlir::edsc::op; + using namespace mlir::edsc::intrinsics; - // Build the AccessInfo which contain all the information needed to build the - // perfectly nest loop nest to perform local reads and clipped writes. - auto accessInfo = makeVectorTransferAccessInfo(); + // 1. Setup all the captures. + ScopedContext scope(FuncBuilder(transfer->getInstruction()), + transfer->getLoc()); + IndexedValue remote(transfer->getMemRef()); + MemRefView view(transfer->getMemRef()); + ValueHandle vectorValue(transfer->getVector()); + VectorView vectorView(transfer->getVector()); + SmallVector ivs(vectorView.rank()); + SmallVector pivs; + for (auto &idx : ivs) { + pivs.push_back(&idx); + } + coalesceCopy(transfer, &pivs, &vectorView); - // Bind vector value for the vector_transfer_write. 
- Expr vectorValue(transfer->getVectorType()); - emitter.bind(Bindable(vectorValue), transfer->getVector()); + auto lbs = vectorView.getLbs(); + auto ubs = vectorView.getUbs(); + auto steps = vectorView.getSteps(); - // clang-format off - auto &ivs = accessInfo.ivs; - auto &lbs = accessInfo.lowerBoundsExprs; - auto &ubs = accessInfo.upperBoundsExprs; - auto &steps = accessInfo.stepExprs; - auto scalarType = tmpMemRefType.getElementType(); - Expr scalarValue(scalarType), tmpAlloc(tmpMemRefType), tmpDealloc(Type{}), vectorView(vectorMemRefType); - auto block = edsc::block({ - tmpAlloc = alloc(tmpMemRefType), - vectorView = vector_type_cast(tmpAlloc, vectorMemRefType), - store(vectorValue, vectorView, MutableArrayRef{zero}), - For(ivs, lbs, ubs, steps, { - scalarValue = load(tmpAlloc, accessInfo.tmpAccessExprs), - store(scalarValue, scalarMemRef, accessInfo.clippedScalarAccessExprs), - }), - tmpDealloc = dealloc(tmpAlloc)}); - // clang-format on + // 2. Emit alloc-store-copy-dealloc. + ValueHandle tmp = alloc(tmpMemRefType()); + IndexedValue local(tmp); + ValueHandle vec = vector_type_cast(tmp, vectorMemRefType()); + STORE(vectorValue, vec, {index_t(0)}); + LoopNestBuilder(pivs, lbs, ubs, steps)({ + // Computes clippedScalarAccessExprs in the loop nest scope (ivs exist). + remote(clip(transfer, view, ivs)) = local(ivs), + }); + (dealloc(tmp)); // vexing parse... - // Emit the MLIR. - emitter.emitStmts(block.getBody()); - - // Finalize rewriting. transfer->erase(); } @@ -430,9 +384,6 @@ struct LowerVectorTransfersPass applyMLPatternsGreedily, VectorTransferExpander>(f); } - - // Thread-safe RAII context with local scope. BumpPtrAllocator freed on exit. - edsc::ScopedEDSCContext raiiContext; }; } // end anonymous namespace diff --git a/mlir/test/EDSC/builder-api-test.cpp b/mlir/test/EDSC/builder-api-test.cpp index 3bb22f0d8aa6..19a1284d87c6 100644 --- a/mlir/test/EDSC/builder-api-test.cpp +++ b/mlir/test/EDSC/builder-api-test.cpp @@ -338,10 +338,10 @@ TEST_FUNC(builder_helpers) { step2 = vA.step(2); LoopNestBuilder({&i, &j}, {lb0, lb1}, {ub0, ub1}, {step0, step1})({ LoopBuilder(&k1, lb2, ub2, step2)({ - C({i, j, k1}) = f7 + A({i, j, k1}) + B({i, j, k1}), + C(i, j, k1) = f7 + A(i, j, k1) + B(i, j, k1), }), LoopBuilder(&k2, lb2, ub2, step2)({ - C({i, j, k2}) += A({i, j, k2}) + B({i, j, k2}), + C(i, j, k2) += A(i, j, k2) + B(i, j, k2), }), }); diff --git a/mlir/test/Transforms/Vectorize/lower_vector_transfers.mlir b/mlir/test/Transforms/Vectorize/lower_vector_transfers.mlir index b82ac08fe595..59287791e140 100644 --- a/mlir/test/Transforms/Vectorize/lower_vector_transfers.mlir +++ b/mlir/test/Transforms/Vectorize/lower_vector_transfers.mlir @@ -1,7 +1,7 @@ // RUN: mlir-opt %s -lower-vector-transfers | FileCheck %s // CHECK: #[[ADD:map[0-9]+]] = (d0, d1) -> (d0 + d1) -// CHECK: #[[SUB:map[0-9]+]] = (d0, d1) -> (d0 - d1) +// CHECK: #[[SUB:map[0-9]+]] = (d0) -> (d0 - 1) // CHECK-LABEL: func @materialize_read_1d() { func @materialize_read_1d() { @@ -58,49 +58,53 @@ func @materialize_read(%M: index, %N: index, %O: index, %P: index) { // CHECK-NEXT: for %[[I1:.*]] = 0 to %arg1 { // CHECK-NEXT: for %[[I2:.*]] = 0 to %arg2 { // CHECK-NEXT: for %[[I3:.*]] = 0 to %arg3 step 5 { - // CHECK-NEXT: %[[C0:.*]] = constant 0 : index - // CHECK-NEXT: %[[C1:.*]] = constant 1 : index - // CHECK: {{.*}} = dim %0, 0 : memref - // CHECK-NEXT: {{.*}} = dim %0, 1 : memref - // CHECK-NEXT: {{.*}} = dim %0, 2 : memref - // CHECK-NEXT: {{.*}} = dim %0, 3 : memref + // CHECK: %[[D0:.*]] = dim %0, 0 : memref + // CHECK-NEXT: 
%[[D1:.*]] = dim %0, 1 : memref + // CHECK-NEXT: %[[D2:.*]] = dim %0, 2 : memref + // CHECK-NEXT: %[[D3:.*]] = dim %0, 3 : memref // CHECK: %[[ALLOC:.*]] = alloc() : memref<5x4x3xf32> // CHECK-NEXT: %[[VECTOR_VIEW:.*]] = vector_type_cast %[[ALLOC]] : memref<5x4x3xf32>, memref<1xvector<5x4x3xf32>> // CHECK-NEXT: for %[[I4:.*]] = 0 to 3 { // CHECK-NEXT: for %[[I5:.*]] = 0 to 4 { // CHECK-NEXT: for %[[I6:.*]] = 0 to 5 { - // CHECK-NEXT: {{.*}} = affine.apply #[[ADD]] + // CHECK-NEXT: %[[C0:.*]] = constant 0 : index + // CHECK-NEXT: %[[C1:.*]] = constant 1 : index + // CHECK-NEXT: {{.*}} = affine.apply #[[ADD]](%[[I0]], %[[I4]]) // CHECK-NEXT: {{.*}} = cmpi "slt", {{.*}}, %[[C0]] : index - // CHECK-NEXT: {{.*}} = affine.apply #[[ADD]] + // CHECK-NEXT: {{.*}} = affine.apply #[[ADD]](%[[I0]], %[[I4]]) // CHECK-NEXT: {{.*}} = cmpi "slt", {{.*}} : index - // CHECK-NEXT: {{.*}} = affine.apply #[[ADD]] - // CHECK-NEXT: {{.*}} = affine.apply #[[SUB]] - // CHECK-NEXT: {{.*}} = select + // CHECK-NEXT: {{.*}} = affine.apply #[[ADD]](%[[I0]], %[[I4]]) + // CHECK-NEXT: {{.*}} = affine.apply #[[SUB]](%[[D0]]) // CHECK-NEXT: {{.*}} = select + // CHECK-NEXT: %[[L0:.*]] = select + // // CHECK-NEXT: {{.*}} = cmpi "slt", {{.*}} : index // CHECK-NEXT: {{.*}} = cmpi "slt", {{.*}} : index - // CHECK-NEXT: {{.*}} = affine.apply #[[SUB]] - // CHECK-NEXT: {{.*}} = select + // CHECK-NEXT: {{.*}} = affine.apply #[[SUB]](%[[D1]]) // CHECK-NEXT: {{.*}} = select + // CHECK-NEXT: %[[L1:.*]] = select + // // CHECK-NEXT: {{.*}} = cmpi "slt", {{.*}} : index // CHECK-NEXT: {{.*}} = cmpi "slt", {{.*}} : index - // CHECK-NEXT: {{.*}} = affine.apply #[[SUB]] + // CHECK-NEXT: {{.*}} = affine.apply #[[SUB]](%[[D2]]) // CHECK-NEXT: {{.*}} = select - // CHECK-NEXT: {{.*}} = select - // CHECK-NEXT: {{.*}} = affine.apply #[[ADD]] + // CHECK-NEXT: %[[L2:.*]] = select + // + // CHECK-NEXT: {{.*}} = affine.apply #[[ADD]](%[[I3]], %[[I6]]) // CHECK-NEXT: {{.*}} = cmpi "slt", {{.*}}, %[[C0]] : index - // CHECK-NEXT: {{.*}} = affine.apply #[[ADD]] - // CHECK-NEXT: {{.*}} = cmpi "slt", {{.*}}, {{.*}} : index - // CHECK-NEXT: {{.*}} = affine.apply #[[ADD]] - // CHECK-NEXT: {{.*}} = affine.apply #[[SUB]] - // CHECK-NEXT: {{.*}} = select {{.*}} : index - // CHECK-NEXT: {{.*}} = select {{.*}} : index - // CHECK-NEXT: {{.*}} = load %0[{{.*}}] : memref - // CHECK-NEXT: store {{.*}}, %[[ALLOC]][{{.*}}] : memref<5x4x3xf32> + // CHECK-NEXT: {{.*}} = affine.apply #[[ADD]](%[[I3]], %[[I6]]) + // CHECK-NEXT: {{.*}} = cmpi "slt", {{.*}} : index + // CHECK-NEXT: {{.*}} = affine.apply #[[ADD]](%[[I3]], %[[I6]]) + // CHECK-NEXT: {{.*}} = affine.apply #[[SUB]](%[[D3]]) + // CHECK-NEXT: {{.*}} = select + // CHECK-NEXT: %[[L3:.*]] = select + // + // CHECK-NEXT: {{.*}} = load %0[%[[L0]], %[[L1]], %[[L2]], %[[L3]]] : memref + // CHECK-NEXT: store {{.*}}, %[[ALLOC]][%[[I6]], %[[I5]], %[[I4]]] : memref<5x4x3xf32> // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: } - // CHECK-NEXT: {{.*}} = load %[[VECTOR_VIEW]][%[[C0]]] : memref<1xvector<5x4x3xf32>> + // CHECK: {{.*}} = load %[[VECTOR_VIEW]][{{.*}}] : memref<1xvector<5x4x3xf32>> // CHECK-NEXT: dealloc %[[ALLOC]] : memref<5x4x3xf32> // CHECK-NEXT: } // CHECK-NEXT: } @@ -108,6 +112,10 @@ func @materialize_read(%M: index, %N: index, %O: index, %P: index) { // CHECK-NEXT: } // CHECK-NEXT: return // CHECK-NEXT:} + + // Check that I0 + I4 (of size 3) read from first index load(L0, ...) 
and write into last index store(..., I4) + // Check that I3 + I6 (of size 5) read from last index load(..., L3) and write into first index store(I6, ...) + // Other dimensions are just accessed with I1, I2 resp. %A = alloc (%M, %N, %O, %P) : memref for %i0 = 0 to %M step 3 { for %i1 = 0 to %N { @@ -129,49 +137,53 @@ func @materialize_write(%M: index, %N: index, %O: index, %P: index) { // CHECK-NEXT: for %[[I1:.*]] = 0 to %arg1 step 4 { // CHECK-NEXT: for %[[I2:.*]] = 0 to %arg2 { // CHECK-NEXT: for %[[I3:.*]] = 0 to %arg3 step 5 { - // CHECK-NEXT: %[[C0:.*]] = constant 0 : index - // CHECK-NEXT: %[[C1:.*]] = constant 1 : index - // CHECK: {{.*}} = dim %0, 0 : memref - // CHECK-NEXT: {{.*}} = dim %0, 1 : memref - // CHECK-NEXT: {{.*}} = dim %0, 2 : memref - // CHECK-NEXT: {{.*}} = dim %0, 3 : memref + // CHECK: %[[D0:.*]] = dim %0, 0 : memref + // CHECK-NEXT: %[[D1:.*]] = dim %0, 1 : memref + // CHECK-NEXT: %[[D2:.*]] = dim %0, 2 : memref + // CHECK-NEXT: %[[D3:.*]] = dim %0, 3 : memref // CHECK: %[[ALLOC:.*]] = alloc() : memref<5x4x3xf32> // CHECK-NEXT: %[[VECTOR_VIEW:.*]] = vector_type_cast {{.*}} : memref<5x4x3xf32>, memref<1xvector<5x4x3xf32>> - // CHECK-NEXT: store %cst, {{.*}}[%[[C0]]] : memref<1xvector<5x4x3xf32>> + // CHECK: store %cst, {{.*}} : memref<1xvector<5x4x3xf32>> // CHECK-NEXT: for %[[I4:.*]] = 0 to 3 { // CHECK-NEXT: for %[[I5:.*]] = 0 to 4 { // CHECK-NEXT: for %[[I6:.*]] = 0 to 5 { - // CHECK-NEXT: {{.*}} = load {{.*}}[%[[I6]], %[[I5]], %[[I4]]] : memref<5x4x3xf32> + // CHECK-NEXT: %[[C0:.*]] = constant 0 : index + // CHECK-NEXT: %[[C1:.*]] = constant 1 : index // CHECK-NEXT: {{.*}} = affine.apply #[[ADD]](%[[I0]], %[[I4]]) // CHECK-NEXT: {{.*}} = cmpi "slt", {{.*}}, %[[C0]] : index // CHECK-NEXT: {{.*}} = affine.apply #[[ADD]](%[[I0]], %[[I4]]) // CHECK-NEXT: {{.*}} = cmpi "slt", {{.*}}, {{.*}} : index // CHECK-NEXT: {{.*}} = affine.apply #[[ADD]](%[[I0]], %[[I4]]) - // CHECK-NEXT: {{.*}} = affine.apply #[[SUB]]({{.*}}, %[[C1]]) + // CHECK-NEXT: {{.*}} = affine.apply #[[SUB]](%[[D0]]) // CHECK-NEXT: {{.*}} = select {{.*}}, {{.*}}, {{.*}} : index - // CHECK-NEXT: {{.*}} = select {{.*}}, %[[C0]], {{.*}} : index + // CHECK-NEXT: %[[S0:.*]] = select {{.*}}, %[[C0]], {{.*}} : index + // // CHECK-NEXT: {{.*}} = affine.apply #[[ADD]](%[[I1]], %[[I5]]) // CHECK-NEXT: {{.*}} = cmpi "slt", {{.*}}, %[[C0]] : index // CHECK-NEXT: {{.*}} = affine.apply #[[ADD]](%[[I1]], %[[I5]]) // CHECK-NEXT: {{.*}} = cmpi "slt", {{.*}}, {{.*}} : index // CHECK-NEXT: {{.*}} = affine.apply #[[ADD]](%[[I1]], %[[I5]]) - // CHECK-NEXT: {{.*}} = affine.apply #[[SUB]]({{.*}}, %[[C1]]) + // CHECK-NEXT: {{.*}} = affine.apply #[[SUB]](%[[D1]]) // CHECK-NEXT: {{.*}} = select {{.*}}, {{.*}}, {{.*}} : index - // CHECK-NEXT: {{.*}} = select {{.*}}, %[[C0]], {{.*}} : index + // CHECK-NEXT: %[[S1:.*]] = select {{.*}}, %[[C0]], {{.*}} : index + // // CHECK-NEXT: {{.*}} = cmpi "slt", %[[I2]], %[[C0]] : index // CHECK-NEXT: {{.*}} = cmpi "slt", %[[I2]], %3 : index - // CHECK-NEXT: {{.*}} = affine.apply #map{{.*}}(%3, %[[C1]]) + // CHECK-NEXT: {{.*}} = affine.apply #[[SUB]](%[[D2]]) // CHECK-NEXT: {{.*}} = select {{.*}}, %[[I2]], {{.*}} : index - // CHECK-NEXT: {{.*}} = select {{.*}}, %[[C0]], {{.*}} : index + // CHECK-NEXT: %[[S2:.*]] = select {{.*}}, %[[C0]], {{.*}} : index + // // CHECK-NEXT: {{.*}} = affine.apply #[[ADD]](%[[I3]], %[[I6]]) // CHECK-NEXT: {{.*}} = cmpi "slt", {{.*}}, %[[C0]] : index // CHECK-NEXT: {{.*}} = affine.apply #[[ADD]](%[[I3]], %[[I6]]) // CHECK-NEXT: {{.*}} = cmpi "slt", {{.*}}, {{.*}} : 
index // CHECK-NEXT: {{.*}} = affine.apply #[[ADD]](%[[I3]], %[[I6]]) - // CHECK-NEXT: {{.*}} = affine.apply #[[SUB]]({{.*}}, %[[C1]]) + // CHECK-NEXT: {{.*}} = affine.apply #[[SUB]](%[[D3]]) // CHECK-NEXT: {{.*}} = select {{.*}}, {{.*}}, {{.*}} : index - // CHECK-NEXT: {{.*}} = select {{.*}}, %[[C0]], {{.*}} : index - // CHECK-NEXT: store {{.*}}, {{.*}}[{{.*}}, {{.*}}, {{.*}}, {{.*}}] : memref + // CHECK-NEXT: %[[S3:.*]] = select {{.*}}, %[[C0]], {{.*}} : index + // + // CHECK-NEXT: {{.*}} = load {{.*}}[%[[I6]], %[[I5]], %[[I4]]] : memref<5x4x3xf32> + // CHECK: store {{.*}}, {{.*}}[%[[S0]], %[[S1]], %[[S2]], %[[S3]]] : memref // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: } @@ -182,6 +194,11 @@ func @materialize_write(%M: index, %N: index, %O: index, %P: index) { // CHECK-NEXT: } // CHECK-NEXT: return // CHECK-NEXT:} + // + // Check that I0 + I4 (of size 3) read from last index load(..., I4) and write into first index store(S0, ...) + // Check that I1 + I5 (of size 4) read from second index load(..., I5, ...) and write into second index store(..., S1, ...) + // Check that I3 + I6 (of size 5) read from first index load(I6, ...) and write into last index store(..., S3) + // Other dimension is just accessed with I2. %A = alloc (%M, %N, %O, %P) : memref %f1 = constant splat, 1.000000e+00> : vector<5x4x3xf32> for %i0 = 0 to %M step 3 {
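
Usage note (not part of the patch): a minimal sketch of how the declarative builders and sugar introduced above compose, modeled on the updated `builder-api-test.cpp` and on `VectorTransferRewriter::rewrite`. The function `copyPointwise`, its `Value *` arguments, and the assumption that an `edsc::ScopedContext` is already active in the caller are illustrative only.

```c++
// Sketch only: assumes an enclosing edsc::ScopedContext (built from a
// FuncBuilder and a Location, as in the rewriters above) is active, and that
// `src` and `dst` are mlir::Value* of identical rank-2 memref type.
void copyPointwise(mlir::Value *src, mlir::Value *dst) {
  using namespace mlir::edsc;
  using namespace mlir::edsc::op;

  MemRefView vSrc(src);        // Captures lbs/ubs/steps of the source memref.
  IndexedValue S(src), D(dst); // Load/store sugar over the two memrefs.
  IndexHandle i, j;            // Induction variables bound by LoopNestBuilder.

  // Perfect loop nest over the source view; `D(i, j) = S(i, j)` emits a load
  // from `src` followed by a store into `dst` on each iteration.
  LoopNestBuilder({&i, &j}, {vSrc.lb(0), vSrc.lb(1)}, {vSrc.ub(0), vSrc.ub(1)},
                  {vSrc.step(0), vSrc.step(1)})({
      D(i, j) = S(i, j),
  });
}
```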