mirror of
https://github.com/intel/llvm.git
synced 2026-02-01 08:56:15 +08:00
Port LowerVectorTransfers from EDSC + AST to declarative builders
This CL removes the dependency of LowerVectorTransfers on the AST version of EDSCs which will be retired. This exhibited a pretty fundamental staging difference in AST-based vs declarative based emission. Since the delayed creation with an AST was staged, the loop order came into existence after the clipping expressions were computed. This now changes as the loops first need to be created declaratively in fixed order and then the clipping expressions are created. Also, due to lack of staging, coalescing cannot be done on the fly anymore and needs to be done either as a pre-pass (current implementation) or as a local transformation on the generated IR (future work). Tests are updated accordingly. PiperOrigin-RevId: 238971631
This commit is contained in:
committed by
jpienaar
parent
6810c8bdc1
commit
f43388e4ce
@@ -26,6 +26,7 @@
|
||||
#include "mlir/AffineOps/AffineOps.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/StandardOps/Ops.h"
|
||||
#include "mlir/SuperVectorOps/SuperVectorOps.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
@@ -291,6 +292,14 @@ public:
|
||||
/// assigning.
|
||||
ValueHandle &operator=(const ValueHandle &other);
|
||||
|
||||
/// Provide a swap operator.
|
||||
void swap(ValueHandle &other) {
|
||||
if (this == &other)
|
||||
return;
|
||||
std::swap(t, other.t);
|
||||
std::swap(v, other.v);
|
||||
}
|
||||
|
||||
/// Implicit conversion useful for automatic conversion to Container<Value*>.
|
||||
operator Value *() const { return getValue(); }
|
||||
|
||||
@@ -310,11 +319,14 @@ public:
|
||||
ArrayRef<NamedAttribute> attributes = {});
|
||||
|
||||
bool hasValue() const { return v != nullptr; }
|
||||
Value *getValue() const { return v; }
|
||||
Value *getValue() const {
|
||||
assert(hasValue() && "Unexpected null value;");
|
||||
return v;
|
||||
}
|
||||
bool hasType() const { return t != Type(); }
|
||||
Type getType() const { return t; }
|
||||
|
||||
private:
|
||||
protected:
|
||||
ValueHandle() : t(), v(nullptr) {}
|
||||
|
||||
Type t;
|
||||
@@ -442,6 +454,12 @@ ValueHandle operator!(ValueHandle value);
|
||||
ValueHandle operator&&(ValueHandle lhs, ValueHandle rhs);
|
||||
ValueHandle operator||(ValueHandle lhs, ValueHandle rhs);
|
||||
ValueHandle operator^(ValueHandle lhs, ValueHandle rhs);
|
||||
ValueHandle operator==(ValueHandle lhs, ValueHandle rhs);
|
||||
ValueHandle operator!=(ValueHandle lhs, ValueHandle rhs);
|
||||
ValueHandle operator<(ValueHandle lhs, ValueHandle rhs);
|
||||
ValueHandle operator<=(ValueHandle lhs, ValueHandle rhs);
|
||||
ValueHandle operator>(ValueHandle lhs, ValueHandle rhs);
|
||||
ValueHandle operator>=(ValueHandle lhs, ValueHandle rhs);
|
||||
|
||||
} // namespace op
|
||||
} // namespace edsc
|
||||
|
||||
@@ -25,6 +25,8 @@
|
||||
#include "mlir/EDSC/Builders.h"
|
||||
#include "mlir/EDSC/Intrinsics.h"
|
||||
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace edsc {
|
||||
|
||||
@@ -45,7 +47,48 @@ struct IndexHandle : public ValueHandle {
|
||||
assert(v->getType() == ScopedContext::getBuilder()->getIndexType() &&
|
||||
"Expected index type");
|
||||
}
|
||||
explicit IndexHandle(ValueHandle v) : ValueHandle(v) {}
|
||||
explicit IndexHandle(ValueHandle v) : ValueHandle(v) {
|
||||
assert(v.getType() == ScopedContext::getBuilder()->getIndexType() &&
|
||||
"Expected index type");
|
||||
}
|
||||
IndexHandle &operator=(const ValueHandle &v) {
|
||||
assert(v.getType() == ScopedContext::getBuilder()->getIndexType() &&
|
||||
"Expected index type");
|
||||
/// Creating a new IndexHandle(v) and then std::swap rightly complains the
|
||||
/// binding has already occurred and that we should use another name.
|
||||
this->t = v.getType();
|
||||
this->v = v.getValue();
|
||||
return *this;
|
||||
}
|
||||
};
|
||||
|
||||
// Base class for MemRefView and VectorView.
|
||||
class View {
|
||||
public:
|
||||
unsigned rank() const { return lbs.size(); }
|
||||
ValueHandle lb(unsigned idx) { return lbs[idx]; }
|
||||
ValueHandle ub(unsigned idx) { return ubs[idx]; }
|
||||
int64_t step(unsigned idx) { return steps[idx]; }
|
||||
std::tuple<ValueHandle, ValueHandle, int64_t> range(unsigned idx) {
|
||||
return std::make_tuple(lbs[idx], ubs[idx], steps[idx]);
|
||||
}
|
||||
void swapRanges(unsigned i, unsigned j) {
|
||||
llvm::errs() << "\nSWAP: " << i << " " << j;
|
||||
if (i == j)
|
||||
return;
|
||||
lbs[i].swap(lbs[j]);
|
||||
ubs[i].swap(ubs[j]);
|
||||
std::swap(steps[i], steps[j]);
|
||||
}
|
||||
|
||||
ArrayRef<ValueHandle> getLbs() { return lbs; }
|
||||
ArrayRef<ValueHandle> getUbs() { return ubs; }
|
||||
ArrayRef<int64_t> getSteps() { return steps; }
|
||||
|
||||
protected:
|
||||
SmallVector<ValueHandle, 8> lbs;
|
||||
SmallVector<ValueHandle, 8> ubs;
|
||||
SmallVector<int64_t, 8> steps;
|
||||
};
|
||||
|
||||
/// A MemRefView represents the information required to step through a
|
||||
@@ -53,35 +96,32 @@ struct IndexHandle : public ValueHandle {
|
||||
/// Fortran subarray model.
|
||||
/// At the moment it can only capture a MemRef with an identity layout map.
|
||||
// TODO(ntv): Support MemRefs with layoutMaps.
|
||||
class MemRefView {
|
||||
class MemRefView : public View {
|
||||
public:
|
||||
explicit MemRefView(Value *v);
|
||||
MemRefView(const MemRefView &) = default;
|
||||
MemRefView &operator=(const MemRefView &) = default;
|
||||
|
||||
unsigned rank() const { return lbs.size(); }
|
||||
unsigned fastestVarying() const { return rank() - 1; }
|
||||
|
||||
IndexHandle lb(unsigned idx) { return lbs[idx]; }
|
||||
IndexHandle ub(unsigned idx) { return ubs[idx]; }
|
||||
int64_t step(unsigned idx) { return steps[idx]; }
|
||||
std::tuple<IndexHandle, IndexHandle, int64_t> range(unsigned idx) {
|
||||
return std::make_tuple(lbs[idx], ubs[idx], steps[idx]);
|
||||
}
|
||||
|
||||
private:
|
||||
friend IndexedValue;
|
||||
|
||||
ValueHandle base;
|
||||
SmallVector<IndexHandle, 8> lbs;
|
||||
SmallVector<IndexHandle, 8> ubs;
|
||||
SmallVector<int64_t, 8> steps;
|
||||
};
|
||||
|
||||
ValueHandle operator+(ValueHandle v, IndexedValue i);
|
||||
ValueHandle operator-(ValueHandle v, IndexedValue i);
|
||||
ValueHandle operator*(ValueHandle v, IndexedValue i);
|
||||
ValueHandle operator/(ValueHandle v, IndexedValue i);
|
||||
/// A VectorView represents the information required to step through a
|
||||
/// Vector accessing each scalar element at a time. It is the counterpart of
|
||||
/// a MemRefView but for vectors. This exists purely for boilerplate avoidance.
|
||||
class VectorView : public View {
|
||||
public:
|
||||
explicit VectorView(Value *v);
|
||||
VectorView(const VectorView &) = default;
|
||||
VectorView &operator=(const VectorView &) = default;
|
||||
|
||||
private:
|
||||
friend IndexedValue;
|
||||
ValueHandle base;
|
||||
};
|
||||
|
||||
/// This helper class is an abstraction over memref, that purely for sugaring
|
||||
/// purposes and allows writing compact expressions such as:
|
||||
@@ -97,32 +137,54 @@ ValueHandle operator/(ValueHandle v, IndexedValue i);
|
||||
/// converting an IndexedValue to a ValueHandle emits an actual load operation.
|
||||
struct IndexedValue {
|
||||
explicit IndexedValue(Type t) : base(t) {}
|
||||
explicit IndexedValue(Value *v, llvm::ArrayRef<ValueHandle> indices = {})
|
||||
: IndexedValue(ValueHandle(v), indices) {}
|
||||
explicit IndexedValue(ValueHandle v, llvm::ArrayRef<ValueHandle> indices = {})
|
||||
: base(v), indices(indices.begin(), indices.end()) {}
|
||||
explicit IndexedValue(Value *v) : IndexedValue(ValueHandle(v)) {}
|
||||
explicit IndexedValue(ValueHandle v) : base(v) {}
|
||||
|
||||
IndexedValue(const IndexedValue &rhs) = default;
|
||||
IndexedValue &operator=(const IndexedValue &rhs) = default;
|
||||
|
||||
ValueHandle operator()() { return ValueHandle(*this); }
|
||||
/// Returns a new `IndexedValue`.
|
||||
IndexedValue operator()(llvm::ArrayRef<ValueHandle> indices = {}) {
|
||||
IndexedValue operator()(ValueHandle index) {
|
||||
IndexedValue res(base);
|
||||
res.indices.push_back(index);
|
||||
return res;
|
||||
}
|
||||
template <typename... Args>
|
||||
IndexedValue operator()(ValueHandle index, Args... indices) {
|
||||
return IndexedValue(base, index).append(indices...);
|
||||
}
|
||||
IndexedValue operator()(llvm::ArrayRef<ValueHandle> indices) {
|
||||
return IndexedValue(base, indices);
|
||||
}
|
||||
IndexedValue operator()(llvm::ArrayRef<IndexHandle> indices) {
|
||||
return IndexedValue(
|
||||
base, llvm::ArrayRef<ValueHandle>(indices.begin(), indices.end()));
|
||||
}
|
||||
|
||||
/// Emits a `store`.
|
||||
// NOLINTNEXTLINE: unconventional-assign-operator
|
||||
InstructionHandle operator=(const IndexedValue &rhs) {
|
||||
ValueHandle rrhs(rhs);
|
||||
assert(getBase().getType().cast<MemRefType>().getRank() == indices.size() &&
|
||||
"Unexpected number of indices to store in MemRef");
|
||||
return intrinsics::STORE(rrhs, getBase(), indices);
|
||||
}
|
||||
// NOLINTNEXTLINE: unconventional-assign-operator
|
||||
InstructionHandle operator=(ValueHandle rhs) {
|
||||
assert(getBase().getType().cast<MemRefType>().getRank() == indices.size() &&
|
||||
"Unexpected number of indices to store in MemRef");
|
||||
return intrinsics::STORE(rhs, getBase(), indices);
|
||||
}
|
||||
|
||||
ValueHandle getBase() const { return base; }
|
||||
|
||||
/// Emits a `load` when converting to a ValueHandle.
|
||||
explicit operator ValueHandle() {
|
||||
operator ValueHandle() const {
|
||||
assert(getBase().getType().cast<MemRefType>().getRank() == indices.size() &&
|
||||
"Unexpected number of indices to store in MemRef");
|
||||
return intrinsics::LOAD(getBase(), indices);
|
||||
}
|
||||
|
||||
ValueHandle getBase() const { return base; }
|
||||
|
||||
/// Operator overloadings.
|
||||
ValueHandle operator+(ValueHandle e);
|
||||
ValueHandle operator-(ValueHandle e);
|
||||
@@ -158,6 +220,17 @@ struct IndexedValue {
|
||||
}
|
||||
|
||||
private:
|
||||
IndexedValue(ValueHandle base, ArrayRef<ValueHandle> indices)
|
||||
: base(base), indices(indices.begin(), indices.end()) {}
|
||||
|
||||
IndexedValue &append() { return *this; }
|
||||
|
||||
template <typename T, typename... Args>
|
||||
IndexedValue &append(T index, Args... indices) {
|
||||
this->indices.push_back(static_cast<ValueHandle>(index));
|
||||
append(indices...);
|
||||
return *this;
|
||||
}
|
||||
ValueHandle base;
|
||||
llvm::SmallVector<ValueHandle, 8> indices;
|
||||
};
|
||||
|
||||
@@ -23,10 +23,14 @@
|
||||
#ifndef MLIR_EDSC_INTRINSICS_H_
|
||||
#define MLIR_EDSC_INTRINSICS_H_
|
||||
|
||||
#include "mlir/EDSC/Builders.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
class MemRefType;
|
||||
class Type;
|
||||
|
||||
namespace edsc {
|
||||
|
||||
class BlockHandle;
|
||||
@@ -94,9 +98,34 @@ InstructionHandle COND_BR(ValueHandle cond, BlockHandle *trueBranch,
|
||||
ArrayRef<ValueHandle *> falseCaptures,
|
||||
ArrayRef<ValueHandle> falseOperands);
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// TODO(ntv): Intrinsics below this line should be TableGen'd.
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
/// Helper variadic abstraction to allow extending to any MLIR op without
|
||||
/// boilerplate or Tablegen.
|
||||
/// Arguably a builder is not a ValueHandle but in practice it is only used as
|
||||
/// an alias to a notional ValueHandle<Op>.
|
||||
/// Implementing it as a subclass allows it to compose all the way to Value*.
|
||||
/// Without subclassing, implicit conversion to Value* would fail when composing
|
||||
/// in patterns such as: `select(a, b, select(c, d, e))`.
|
||||
template <typename Op> struct EDSCValueBuilder : public ValueHandle {
|
||||
template <typename... Args>
|
||||
EDSCValueBuilder(Args... args)
|
||||
: ValueHandle(ValueHandle::create<Op>(std::forward<Args>(args)...)) {}
|
||||
EDSCValueBuilder() = delete;
|
||||
};
|
||||
|
||||
template <typename Op>
|
||||
struct EDSCInstructionBuilder : public InstructionHandle {
|
||||
template <typename... Args>
|
||||
EDSCInstructionBuilder(Args... args)
|
||||
: InstructionHandle(
|
||||
InstructionHandle::create<Op>(std::forward<Args>(args)...)) {}
|
||||
EDSCInstructionBuilder() = delete;
|
||||
};
|
||||
|
||||
using alloc = EDSCValueBuilder<AllocOp>;
|
||||
using dealloc = EDSCInstructionBuilder<DeallocOp>;
|
||||
using select = EDSCValueBuilder<SelectOp>;
|
||||
using vector_type_cast = EDSCValueBuilder<VectorTypeCastOp>;
|
||||
|
||||
/// Builds an mlir::LoadOp with the proper `operands` that each must have
|
||||
/// captured an mlir::Value*.
|
||||
/// Returns a ValueHandle to the produced mlir::Value*.
|
||||
@@ -104,19 +133,16 @@ ValueHandle LOAD(ValueHandle base, llvm::ArrayRef<ValueHandle> indices);
|
||||
|
||||
/// Builds an mlir::ReturnOp with the proper `operands` that each must have
|
||||
/// captured an mlir::Value*.
|
||||
/// Returns an empty ValueHandle.
|
||||
/// Returns an InstructionHandle.
|
||||
InstructionHandle RETURN(llvm::ArrayRef<ValueHandle> operands);
|
||||
|
||||
/// Builds an mlir::StoreOp with the proper `operands` that each must have
|
||||
/// captured an mlir::Value*.
|
||||
/// Returns an empty ValueHandle.
|
||||
/// Returns an InstructionHandle.
|
||||
InstructionHandle STORE(ValueHandle value, ValueHandle base,
|
||||
llvm::ArrayRef<ValueHandle> indices);
|
||||
|
||||
} // namespace intrinsics
|
||||
|
||||
} // namespace edsc
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif // MLIR_EDSC_INTRINSICS_H_
|
||||
|
||||
@@ -381,3 +381,36 @@ ValueHandle mlir::edsc::op::operator&&(ValueHandle lhs, ValueHandle rhs) {
|
||||
ValueHandle mlir::edsc::op::operator||(ValueHandle lhs, ValueHandle rhs) {
|
||||
return !(!lhs && !rhs);
|
||||
}
|
||||
|
||||
static ValueHandle createComparisonExpr(CmpIPredicate predicate,
|
||||
ValueHandle lhs, ValueHandle rhs) {
|
||||
auto lhsType = lhs.getType();
|
||||
auto rhsType = rhs.getType();
|
||||
assert(lhsType == rhsType && "cannot mix types in operators");
|
||||
assert((lhsType.isa<IndexType>() || lhsType.isa<IntegerType>()) &&
|
||||
"only integer comparisons are supported");
|
||||
|
||||
auto op = ScopedContext::getBuilder()->create<CmpIOp>(
|
||||
ScopedContext::getLocation(), predicate, lhs.getValue(), rhs.getValue());
|
||||
return ValueHandle(op->getResult());
|
||||
}
|
||||
|
||||
ValueHandle mlir::edsc::op::operator==(ValueHandle lhs, ValueHandle rhs) {
|
||||
return createComparisonExpr(CmpIPredicate::EQ, lhs, rhs);
|
||||
}
|
||||
ValueHandle mlir::edsc::op::operator!=(ValueHandle lhs, ValueHandle rhs) {
|
||||
return createComparisonExpr(CmpIPredicate::NE, lhs, rhs);
|
||||
}
|
||||
ValueHandle mlir::edsc::op::operator<(ValueHandle lhs, ValueHandle rhs) {
|
||||
// TODO(ntv,zinenko): signed by default, how about unsigned?
|
||||
return createComparisonExpr(CmpIPredicate::SLT, lhs, rhs);
|
||||
}
|
||||
ValueHandle mlir::edsc::op::operator<=(ValueHandle lhs, ValueHandle rhs) {
|
||||
return createComparisonExpr(CmpIPredicate::SLE, lhs, rhs);
|
||||
}
|
||||
ValueHandle mlir::edsc::op::operator>(ValueHandle lhs, ValueHandle rhs) {
|
||||
return createComparisonExpr(CmpIPredicate::SGT, lhs, rhs);
|
||||
}
|
||||
ValueHandle mlir::edsc::op::operator>=(ValueHandle lhs, ValueHandle rhs) {
|
||||
return createComparisonExpr(CmpIPredicate::SGE, lhs, rhs);
|
||||
}
|
||||
|
||||
@@ -22,21 +22,21 @@
|
||||
using namespace mlir;
|
||||
using namespace mlir::edsc;
|
||||
|
||||
static SmallVector<IndexHandle, 8> getMemRefSizes(Value *memRef) {
|
||||
static SmallVector<ValueHandle, 8> getMemRefSizes(Value *memRef) {
|
||||
MemRefType memRefType = memRef->getType().cast<MemRefType>();
|
||||
|
||||
auto maps = memRefType.getAffineMaps();
|
||||
(void)maps;
|
||||
assert((maps.empty() || (maps.size() == 1 && maps[0].isIdentity())) &&
|
||||
"Layout maps not supported");
|
||||
SmallVector<IndexHandle, 8> res;
|
||||
SmallVector<ValueHandle, 8> res;
|
||||
res.reserve(memRefType.getShape().size());
|
||||
const auto &shape = memRefType.getShape();
|
||||
for (unsigned idx = 0, n = shape.size(); idx < n; ++idx) {
|
||||
if (shape[idx] == -1) {
|
||||
res.push_back(IndexHandle(ValueHandle::create<DimOp>(memRef, idx)));
|
||||
res.push_back(ValueHandle::create<DimOp>(memRef, idx));
|
||||
} else {
|
||||
res.push_back(IndexHandle(static_cast<index_t>(shape[idx])));
|
||||
res.push_back(static_cast<index_t>(shape[idx]));
|
||||
}
|
||||
}
|
||||
return res;
|
||||
@@ -47,12 +47,22 @@ mlir::edsc::MemRefView::MemRefView(Value *v) : base(v) {
|
||||
|
||||
auto memrefSizeValues = getMemRefSizes(v);
|
||||
for (auto &size : memrefSizeValues) {
|
||||
lbs.push_back(IndexHandle(static_cast<index_t>(0)));
|
||||
lbs.push_back(static_cast<index_t>(0));
|
||||
ubs.push_back(size);
|
||||
steps.push_back(1);
|
||||
}
|
||||
}
|
||||
|
||||
mlir::edsc::VectorView::VectorView(Value *v) : base(v) {
|
||||
auto vectorType = v->getType().cast<VectorType>();
|
||||
|
||||
for (auto s : vectorType.getShape()) {
|
||||
lbs.push_back(static_cast<index_t>(0));
|
||||
ubs.push_back(static_cast<index_t>(s));
|
||||
steps.push_back(1);
|
||||
}
|
||||
}
|
||||
|
||||
/// Operator overloadings.
|
||||
ValueHandle mlir::edsc::IndexedValue::operator+(ValueHandle e) {
|
||||
using op::operator+;
|
||||
@@ -87,23 +97,3 @@ InstructionHandle mlir::edsc::IndexedValue::operator/=(ValueHandle e) {
|
||||
using op::operator/;
|
||||
return intrinsics::STORE(*this / e, getBase(), indices);
|
||||
}
|
||||
|
||||
ValueHandle mlir::edsc::operator+(ValueHandle v, IndexedValue i) {
|
||||
using op::operator+;
|
||||
return v + static_cast<ValueHandle>(i);
|
||||
}
|
||||
|
||||
ValueHandle mlir::edsc::operator-(ValueHandle v, IndexedValue i) {
|
||||
using op::operator-;
|
||||
return v - static_cast<ValueHandle>(i);
|
||||
}
|
||||
|
||||
ValueHandle mlir::edsc::operator*(ValueHandle v, IndexedValue i) {
|
||||
using op::operator*;
|
||||
return v * static_cast<ValueHandle>(i);
|
||||
}
|
||||
|
||||
ValueHandle mlir::edsc::operator/(ValueHandle v, IndexedValue i) {
|
||||
using op::operator/;
|
||||
return v / static_cast<ValueHandle>(i);
|
||||
}
|
||||
|
||||
@@ -18,6 +18,7 @@
|
||||
#include "mlir/EDSC/Intrinsics.h"
|
||||
#include "mlir/EDSC/Builders.h"
|
||||
#include "mlir/IR/AffineExpr.h"
|
||||
#include "mlir/SuperVectorOps/SuperVectorOps.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::edsc;
|
||||
|
||||
@@ -25,7 +25,8 @@
|
||||
#include "mlir/Analysis/NestedMatcher.h"
|
||||
#include "mlir/Analysis/Utils.h"
|
||||
#include "mlir/Analysis/VectorAnalysis.h"
|
||||
#include "mlir/EDSC/MLIREmitter.h"
|
||||
#include "mlir/EDSC/Builders.h"
|
||||
#include "mlir/EDSC/Helpers.h"
|
||||
#include "mlir/IR/AffineExpr.h"
|
||||
#include "mlir/IR/AffineMap.h"
|
||||
#include "mlir/IR/Attributes.h"
|
||||
@@ -43,82 +44,14 @@
|
||||
#include "mlir/Transforms/Passes.h"
|
||||
|
||||
#include "llvm/ADT/SetVector.h"
|
||||
#include "llvm/Support/Allocator.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
///
|
||||
/// Implements lowering of VectorTransferReadOp and VectorTransferWriteOp to a
|
||||
/// proper abstraction for the hardware.
|
||||
///
|
||||
/// For now only a simple loop nest is emitted.
|
||||
/// For now, we only emit a simple loop nest that performs clipped pointwise
|
||||
/// copies from a remote to a locally allocated memory.
|
||||
///
|
||||
|
||||
using llvm::dbgs;
|
||||
using llvm::SetVector;
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
#define DEBUG_TYPE "lower-vector-transfers"
|
||||
|
||||
namespace {
|
||||
/// Helper structure to hold information about loop nest, clipped accesses to
|
||||
/// the original scalar MemRef as well as full accesses to temporary MemRef in
|
||||
/// local storage.
|
||||
struct VectorTransferAccessInfo {
|
||||
// `ivs` are bound for `For` Stmt at `For` Stmt construction time.
|
||||
llvm::SmallVector<edsc::Expr, 8> ivs;
|
||||
llvm::SmallVector<edsc::Expr, 8> lowerBoundsExprs;
|
||||
llvm::SmallVector<edsc::Expr, 8> upperBoundsExprs;
|
||||
llvm::SmallVector<edsc::Expr, 8> stepExprs;
|
||||
llvm::SmallVector<edsc::Expr, 8> clippedScalarAccessExprs;
|
||||
llvm::SmallVector<edsc::Expr, 8> tmpAccessExprs;
|
||||
};
|
||||
|
||||
template <typename VectorTransferOpTy> class VectorTransferRewriter {
|
||||
public:
|
||||
/// Perform the rewrite using the `emitter`.
|
||||
VectorTransferRewriter(VectorTransferOpTy *transfer,
|
||||
MLFuncLoweringRewriter *rewriter,
|
||||
MLFuncGlobalLoweringState *state);
|
||||
|
||||
/// Perform the rewrite using the `emitter`.
|
||||
void rewrite();
|
||||
|
||||
/// Helper class which creates clipped memref accesses to support lowering of
|
||||
/// the vector_transfer operation.
|
||||
VectorTransferAccessInfo makeVectorTransferAccessInfo();
|
||||
|
||||
private:
|
||||
VectorTransferOpTy *transfer;
|
||||
MLFuncLoweringRewriter *rewriter;
|
||||
MLFuncGlobalLoweringState *state;
|
||||
|
||||
MemRefType memrefType;
|
||||
ArrayRef<int64_t> memrefShape;
|
||||
VectorType vectorType;
|
||||
ArrayRef<int64_t> vectorShape;
|
||||
AffineMap permutationMap;
|
||||
|
||||
/// Used for staging the transfer in a local scalar buffer.
|
||||
MemRefType tmpMemRefType;
|
||||
/// View of tmpMemRefType as one vector, used in vector load/store to tmp
|
||||
/// buffer.
|
||||
MemRefType vectorMemRefType;
|
||||
|
||||
// EDSC `emitter` and Exprs that are pre-bound at construction time.
|
||||
edsc::MLIREmitter emitter;
|
||||
// vectorSizes are bound to the actual constant sizes of vectorType.
|
||||
llvm::SmallVector<edsc::Expr, 8> vectorSizes;
|
||||
// accesses are bound to transfer->getIndices()
|
||||
llvm::SmallVector<edsc::Expr, 8> accesses;
|
||||
// `zero` and `one` are bound emitter.zero() and emitter.one().
|
||||
// `scalarMemRef` is bound to `transfer->getMemRef()`.
|
||||
edsc::Expr zero, one, scalarMemRef;
|
||||
};
|
||||
|
||||
} // end anonymous namespace
|
||||
|
||||
/// Consider the case:
|
||||
///
|
||||
/// ```mlir {.mlir}
|
||||
@@ -134,141 +67,160 @@ private:
|
||||
/// }}}
|
||||
/// ```
|
||||
///
|
||||
/// The following constructs the `loadAccessExpr` that supports the emission of
|
||||
/// MLIR resembling:
|
||||
/// The rewriters construct loop and indices that access MemRef A in a pattern
|
||||
/// resembling the following (while guaranteeing an always full-tile
|
||||
/// abstraction):
|
||||
///
|
||||
/// ```mlir
|
||||
/// for %d1 = 0 to 256 {
|
||||
/// for %d2 = 0 to 32 {
|
||||
/// ```mlir {.mlir}
|
||||
/// for %d2 = 0 to 256 {
|
||||
/// for %d1 = 0 to 32 {
|
||||
/// %s = %A[%i0, %i1 + %d1, %i2 + %d2] : f32
|
||||
/// %tmp[%d2, %d1] = %s
|
||||
/// }
|
||||
/// }
|
||||
/// ```
|
||||
///
|
||||
/// Notice in particular the order of loops iterating over the vector size
|
||||
/// (i.e. 256x32 instead of 32x256). This results in contiguous accesses along
|
||||
/// the most minor dimension of the original scalar tensor. On many hardware
|
||||
/// architectures this will result in better utilization of the underlying
|
||||
/// memory subsystem (e.g. prefetchers, DMAs, #memory transactions, etc...).
|
||||
///
|
||||
/// This additionally performs clipping as described in
|
||||
/// `VectorTransferRewriter<VectorTransferReadOp>::rewrite` by emitting:
|
||||
/// In the current state, only a clipping transfer is implemented by `clip`,
|
||||
/// which creates individual indexing expressions of the form:
|
||||
///
|
||||
/// ```mlir-dsc
|
||||
/// select(i + ii < zero, zero, select(i + ii < N, i + ii, N - one))
|
||||
/// SELECT(i + ii < zero, zero, SELECT(i + ii < N, i + ii, N - one))
|
||||
/// ```
|
||||
template <typename VectorTransferOpTy>
|
||||
VectorTransferAccessInfo
|
||||
VectorTransferRewriter<VectorTransferOpTy>::makeVectorTransferAccessInfo() {
|
||||
using namespace mlir::edsc;
|
||||
using namespace edsc::op;
|
||||
|
||||
// Create new Exprs for ivs, they will be bound at `For` Stmt
|
||||
// construction.
|
||||
auto ivs = makeNewExprs(vectorShape.size(), this->rewriter->getIndexType());
|
||||
using namespace mlir;
|
||||
|
||||
// Create and bind Exprs to refer to the Value for memref sizes.
|
||||
auto memRefSizes = emitter.makeBoundMemRefShape(transfer->getMemRef());
|
||||
#define DEBUG_TYPE "lower-vector-transfers"
|
||||
|
||||
// Create the edsc::Expr for the clipped and transposes access expressions
|
||||
// using the permutationMap. Additionally, capture the index accessing the
|
||||
// most minor dimension.
|
||||
int coalescingIndex = -1;
|
||||
auto clippedScalarAccessExprs = copyExprs(accesses);
|
||||
auto tmpAccessExprs = copyExprs(ivs);
|
||||
llvm::DenseSet<unsigned> clipped;
|
||||
for (auto it : llvm::enumerate(permutationMap.getResults())) {
|
||||
if (auto affineExpr = it.value().template dyn_cast<AffineDimExpr>()) {
|
||||
auto pos = affineExpr.getPosition();
|
||||
auto i = clippedScalarAccessExprs[pos];
|
||||
auto ii = ivs[it.index()];
|
||||
auto N = memRefSizes[pos];
|
||||
clippedScalarAccessExprs[pos] =
|
||||
select(i + ii < zero, zero, select(i + ii < N, i + ii, N - one));
|
||||
if (pos == clippedScalarAccessExprs.size() - 1) {
|
||||
// If a result of the permutation_map accesses the most minor dimension
|
||||
// then we record it.
|
||||
coalescingIndex = it.index();
|
||||
}
|
||||
// Temporarily record already clipped accesses to avoid double clipping.
|
||||
// TODO(ntv): remove when fully unrolled dimensions are clipped properly.
|
||||
clipped.insert(pos);
|
||||
} else {
|
||||
// Sanity check.
|
||||
assert(it.value().template cast<AffineConstantExpr>().getValue() == 0 &&
|
||||
"Expected dim or 0 in permutationMap");
|
||||
}
|
||||
namespace {
|
||||
|
||||
/// Lowers VectorTransferOp into a combination of:
|
||||
/// 1. local memory allocation;
|
||||
/// 2. perfect loop nest over:
|
||||
/// a. scalar load/stores from local buffers (viewed as a scalar memref);
|
||||
/// a. scalar store/load to original memref (with clipping).
|
||||
/// 3. vector_load/store
|
||||
/// 4. local memory deallocation.
|
||||
/// Minor variations occur depending on whether a VectorTransferReadOp or
|
||||
/// a VectorTransferWriteOp is rewritten.
|
||||
template <typename VectorTransferOpTy> class VectorTransferRewriter {
|
||||
public:
|
||||
VectorTransferRewriter(VectorTransferOpTy *transfer,
|
||||
MLFuncLoweringRewriter *rewriter,
|
||||
MLFuncGlobalLoweringState *state);
|
||||
|
||||
/// Used for staging the transfer in a local scalar buffer.
|
||||
MemRefType tmpMemRefType() {
|
||||
auto vectorType = transfer->getVectorType();
|
||||
return MemRefType::get(vectorType.getShape(), vectorType.getElementType(),
|
||||
{}, 0);
|
||||
}
|
||||
/// View of tmpMemRefType as one vector, used in vector load/store to tmp
|
||||
/// buffer.
|
||||
MemRefType vectorMemRefType() {
|
||||
return MemRefType::get({1}, transfer->getVectorType(), {}, 0);
|
||||
}
|
||||
/// Performs the rewrite.
|
||||
void rewrite();
|
||||
|
||||
// At this point, fully unrolled dimensions have not been clipped because they
|
||||
// do not appear in the permutation map. As a consequence they may access out
|
||||
// of bounds. We currently do not have enough information to determine which
|
||||
// of those access dimensions have been fully unrolled.
|
||||
// Clip one more time to ensure correctness for fully-unrolled dimensions.
|
||||
// TODO(ntv): clip just what is needed once we pass the proper information.
|
||||
// TODO(ntv): when we get there, also ensure we only clip when dimensions are
|
||||
// not divisible (i.e. simple test that can be hoisted outside loop).
|
||||
for (unsigned pos = 0; pos < clippedScalarAccessExprs.size(); ++pos) {
|
||||
if (clipped.count(pos) > 0) {
|
||||
private:
|
||||
VectorTransferOpTy *transfer;
|
||||
MLFuncLoweringRewriter *rewriter;
|
||||
MLFuncGlobalLoweringState *state;
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
/// Analyzes the `transfer` to find an access dimension along the fastest remote
|
||||
/// MemRef dimension. If such a dimension with coalescing properties is found,
|
||||
/// `pivs` and `vectorView` are swapped so that the invocation of
|
||||
/// LoopNestBuilder captures it in the innermost loop.
|
||||
template <typename VectorTransferOpTy>
|
||||
void coalesceCopy(VectorTransferOpTy *transfer,
|
||||
SmallVectorImpl<edsc::ValueHandle *> *pivs,
|
||||
edsc::VectorView *vectorView) {
|
||||
// rank of the remote memory access, coalescing behavior occurs on the
|
||||
// innermost memory dimension.
|
||||
auto remoteRank = transfer->getMemRefType().getRank();
|
||||
// Iterate over the results expressions of the permutation map to determine
|
||||
// the loop order for creating pointwise copies between remote and local
|
||||
// memories.
|
||||
int coalescedIdx = -1;
|
||||
auto exprs = transfer->getPermutationMap().getResults();
|
||||
for (auto en : llvm::enumerate(exprs)) {
|
||||
auto dim = en.value().template dyn_cast<AffineDimExpr>();
|
||||
if (!dim) {
|
||||
continue;
|
||||
}
|
||||
auto i = clippedScalarAccessExprs[pos];
|
||||
auto N = memRefSizes[pos];
|
||||
clippedScalarAccessExprs[pos] =
|
||||
select(i < zero, zero, select(i < N, i, N - one));
|
||||
auto memRefDim = dim.getPosition();
|
||||
if (memRefDim == remoteRank - 1) {
|
||||
// memRefDim has coalescing properties, it should be swapped in the last
|
||||
// position.
|
||||
assert(coalescedIdx == -1 && "Unexpected > 1 coalesced indices");
|
||||
coalescedIdx = en.index();
|
||||
}
|
||||
}
|
||||
if (coalescedIdx >= 0) {
|
||||
std::swap(pivs->back(), (*pivs)[coalescedIdx]);
|
||||
vectorView->swapRanges(pivs->size() - 1, coalescedIdx);
|
||||
}
|
||||
}
|
||||
|
||||
/// Emits remote memory accesses that are clipped to the boundaries of the
|
||||
/// MemRef.
|
||||
template <typename VectorTransferOpTy>
|
||||
static llvm::SmallVector<edsc::ValueHandle, 8>
|
||||
clip(VectorTransferOpTy *transfer, edsc::MemRefView &view,
|
||||
ArrayRef<edsc::IndexHandle> ivs) {
|
||||
using namespace mlir::edsc;
|
||||
using namespace edsc::op;
|
||||
using edsc::intrinsics::select;
|
||||
|
||||
IndexHandle zero(index_t(0)), one(index_t(1));
|
||||
llvm::SmallVector<edsc::ValueHandle, 8> memRefAccess(transfer->getIndices());
|
||||
llvm::SmallVector<edsc::ValueHandle, 8> clippedScalarAccessExprs(
|
||||
memRefAccess.size(), edsc::IndexHandle());
|
||||
|
||||
// Indices accessing to remote memory are clipped and their expressions are
|
||||
// returned in clippedScalarAccessExprs.
|
||||
for (unsigned memRefDim = 0; memRefDim < clippedScalarAccessExprs.size();
|
||||
++memRefDim) {
|
||||
// Linear search on a small number of entries.
|
||||
int loopIndex = -1;
|
||||
auto exprs = transfer->getPermutationMap().getResults();
|
||||
for (auto en : llvm::enumerate(exprs)) {
|
||||
auto expr = en.value();
|
||||
auto dim = expr.template dyn_cast<AffineDimExpr>();
|
||||
// Sanity check.
|
||||
assert(dim || expr.template cast<AffineConstantExpr>().getValue() == 0 &&
|
||||
"Expected dim or 0 in permutationMap");
|
||||
if (dim && memRefDim == dim.getPosition()) {
|
||||
loopIndex = en.index();
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// We cannot distinguish atm between unrolled dimensions that implement
|
||||
// the "always full" tile abstraction and need clipping from the other
|
||||
// ones. So we conservatively clip everything.
|
||||
auto N = view.ub(memRefDim);
|
||||
auto i = memRefAccess[memRefDim];
|
||||
if (loopIndex < 0) {
|
||||
clippedScalarAccessExprs[memRefDim] =
|
||||
select(i < zero, zero, select(i < N, i, N - one));
|
||||
} else {
|
||||
auto ii = ivs[loopIndex];
|
||||
clippedScalarAccessExprs[memRefDim] =
|
||||
select(i + ii < zero, zero, select(i + ii < N, i + ii, N - one));
|
||||
}
|
||||
}
|
||||
|
||||
// Create the proper bindables for lbs, ubs and steps. Additionally, if we
|
||||
// recorded a coalescing index, permute the loop informations.
|
||||
auto lbs = makeNewExprs(ivs.size(), this->rewriter->getIndexType());
|
||||
auto ubs = copyExprs(vectorSizes);
|
||||
auto steps = makeNewExprs(ivs.size(), this->rewriter->getIndexType());
|
||||
if (coalescingIndex >= 0) {
|
||||
std::swap(ivs[coalescingIndex], ivs.back());
|
||||
std::swap(lbs[coalescingIndex], lbs.back());
|
||||
std::swap(ubs[coalescingIndex], ubs.back());
|
||||
std::swap(steps[coalescingIndex], steps.back());
|
||||
}
|
||||
emitter
|
||||
.template bindZipRangeConstants<ConstantIndexOp>(
|
||||
llvm::zip(lbs, SmallVector<int64_t, 8>(ivs.size(), 0)))
|
||||
.template bindZipRangeConstants<ConstantIndexOp>(
|
||||
llvm::zip(steps, SmallVector<int64_t, 8>(ivs.size(), 1)));
|
||||
|
||||
return VectorTransferAccessInfo{ivs,
|
||||
copyExprs(lbs),
|
||||
ubs,
|
||||
copyExprs(steps),
|
||||
clippedScalarAccessExprs,
|
||||
tmpAccessExprs};
|
||||
return clippedScalarAccessExprs;
|
||||
}
|
||||
|
||||
template <typename VectorTransferOpTy>
|
||||
VectorTransferRewriter<VectorTransferOpTy>::VectorTransferRewriter(
|
||||
VectorTransferOpTy *transfer, MLFuncLoweringRewriter *rewriter,
|
||||
MLFuncGlobalLoweringState *state)
|
||||
: transfer(transfer), rewriter(rewriter), state(state),
|
||||
memrefType(transfer->getMemRefType()), memrefShape(memrefType.getShape()),
|
||||
vectorType(transfer->getVectorType()), vectorShape(vectorType.getShape()),
|
||||
permutationMap(transfer->getPermutationMap()),
|
||||
tmpMemRefType(
|
||||
MemRefType::get(vectorShape, vectorType.getElementType(), {}, 0)),
|
||||
vectorMemRefType(MemRefType::get({1}, vectorType, {}, 0)),
|
||||
emitter(edsc::MLIREmitter(rewriter->getBuilder(), transfer->getLoc())),
|
||||
vectorSizes(
|
||||
edsc::makeNewExprs(vectorShape.size(), rewriter->getIndexType())),
|
||||
zero(emitter.zero()), one(emitter.one()),
|
||||
scalarMemRef(transfer->getMemRefType()) {
|
||||
// Bind the Bindable.
|
||||
SmallVector<Value *, 8> transferIndices(transfer->getIndices());
|
||||
accesses = edsc::makeNewExprs(transferIndices.size(),
|
||||
this->rewriter->getIndexType());
|
||||
emitter.bind(edsc::Bindable(scalarMemRef), transfer->getMemRef())
|
||||
.template bindZipRangeConstants<ConstantIndexOp>(
|
||||
llvm::zip(vectorSizes, vectorShape))
|
||||
.template bindZipRange(llvm::zip(accesses, transfer->getIndices()));
|
||||
};
|
||||
: transfer(transfer), rewriter(rewriter), state(state){};
|
||||
|
||||
/// Lowers VectorTransferReadOp into a combination of:
|
||||
/// 1. local memory allocation;
|
||||
@@ -313,38 +265,39 @@ VectorTransferRewriter<VectorTransferOpTy>::VectorTransferRewriter(
|
||||
/// TODO(ntv): support non-data-parallel operations.
|
||||
template <> void VectorTransferRewriter<VectorTransferReadOp>::rewrite() {
|
||||
using namespace mlir::edsc;
|
||||
using namespace mlir::edsc::op;
|
||||
using namespace mlir::edsc::intrinsics;
|
||||
|
||||
// Build the AccessInfo which contain all the information needed to build the
|
||||
// perfectly nest loop nest to perform clipped reads and local writes.
|
||||
auto accessInfo = makeVectorTransferAccessInfo();
|
||||
// 1. Setup all the captures.
|
||||
ScopedContext scope(FuncBuilder(transfer->getInstruction()),
|
||||
transfer->getLoc());
|
||||
IndexedValue remote(transfer->getMemRef());
|
||||
MemRefView view(transfer->getMemRef());
|
||||
VectorView vectorView(transfer->getVector());
|
||||
SmallVector<IndexHandle, 8> ivs(vectorView.rank());
|
||||
SmallVector<ValueHandle *, 8> pivs;
|
||||
for (auto &idx : ivs) {
|
||||
pivs.push_back(&idx);
|
||||
}
|
||||
coalesceCopy(transfer, &pivs, &vectorView);
|
||||
|
||||
// clang-format off
|
||||
auto &ivs = accessInfo.ivs;
|
||||
auto &lbs = accessInfo.lowerBoundsExprs;
|
||||
auto &ubs = accessInfo.upperBoundsExprs;
|
||||
auto &steps = accessInfo.stepExprs;
|
||||
auto lbs = vectorView.getLbs();
|
||||
auto ubs = vectorView.getUbs();
|
||||
auto steps = vectorView.getSteps();
|
||||
|
||||
auto vectorType = this->transfer->getVectorType();
|
||||
auto scalarType = this->transfer->getMemRefType().getElementType();
|
||||
|
||||
Expr scalarValue(scalarType), vectorValue(vectorType), tmpAlloc(tmpMemRefType), tmpDealloc(Type{}), vectorView(vectorMemRefType);
|
||||
auto block = edsc::block({
|
||||
tmpAlloc = alloc(tmpMemRefType),
|
||||
vectorView = vector_type_cast(Expr(tmpAlloc), vectorMemRefType),
|
||||
For(ivs, lbs, ubs, steps, {
|
||||
scalarValue = load(scalarMemRef, accessInfo.clippedScalarAccessExprs),
|
||||
store(scalarValue, tmpAlloc, accessInfo.tmpAccessExprs),
|
||||
}),
|
||||
vectorValue = load(vectorView, {zero}),
|
||||
tmpDealloc = dealloc(tmpAlloc)
|
||||
// 2. Emit alloc-copy-load-dealloc.
|
||||
ValueHandle tmp = alloc(tmpMemRefType());
|
||||
IndexedValue local(tmp);
|
||||
ValueHandle vec = vector_type_cast(tmp, vectorMemRefType());
|
||||
LoopNestBuilder(pivs, lbs, ubs, steps)({
|
||||
// Computes clippedScalarAccessExprs in the loop nest scope (ivs exist).
|
||||
local(ivs) = remote(clip(transfer, view, ivs)),
|
||||
});
|
||||
// clang-format on
|
||||
ValueHandle vectorValue = LOAD(vec, {index_t(0)});
|
||||
(dealloc(tmp)); // vexing parse
|
||||
|
||||
// Emit the MLIR.
|
||||
emitter.emitStmts(block.getBody());
|
||||
|
||||
// Finalize rewriting.
|
||||
transfer->replaceAllUsesWith(emitter.getValue(vectorValue));
|
||||
// 3. Propagate.
|
||||
transfer->replaceAllUsesWith(vectorValue.getValue());
|
||||
transfer->erase();
|
||||
}
|
||||
|
||||
@@ -368,37 +321,38 @@ template <> void VectorTransferRewriter<VectorTransferReadOp>::rewrite() {
|
||||
/// TODO(ntv): support non-data-parallel operations.
|
||||
template <> void VectorTransferRewriter<VectorTransferWriteOp>::rewrite() {
|
||||
using namespace mlir::edsc;
|
||||
using namespace mlir::edsc::op;
|
||||
using namespace mlir::edsc::intrinsics;
|
||||
|
||||
// Build the AccessInfo which contain all the information needed to build the
|
||||
// perfectly nest loop nest to perform local reads and clipped writes.
|
||||
auto accessInfo = makeVectorTransferAccessInfo();
|
||||
// 1. Setup all the captures.
|
||||
ScopedContext scope(FuncBuilder(transfer->getInstruction()),
|
||||
transfer->getLoc());
|
||||
IndexedValue remote(transfer->getMemRef());
|
||||
MemRefView view(transfer->getMemRef());
|
||||
ValueHandle vectorValue(transfer->getVector());
|
||||
VectorView vectorView(transfer->getVector());
|
||||
SmallVector<IndexHandle, 8> ivs(vectorView.rank());
|
||||
SmallVector<ValueHandle *, 8> pivs;
|
||||
for (auto &idx : ivs) {
|
||||
pivs.push_back(&idx);
|
||||
}
|
||||
coalesceCopy(transfer, &pivs, &vectorView);
|
||||
|
||||
// Bind vector value for the vector_transfer_write.
|
||||
Expr vectorValue(transfer->getVectorType());
|
||||
emitter.bind(Bindable(vectorValue), transfer->getVector());
|
||||
auto lbs = vectorView.getLbs();
|
||||
auto ubs = vectorView.getUbs();
|
||||
auto steps = vectorView.getSteps();
|
||||
|
||||
// clang-format off
|
||||
auto &ivs = accessInfo.ivs;
|
||||
auto &lbs = accessInfo.lowerBoundsExprs;
|
||||
auto &ubs = accessInfo.upperBoundsExprs;
|
||||
auto &steps = accessInfo.stepExprs;
|
||||
auto scalarType = tmpMemRefType.getElementType();
|
||||
Expr scalarValue(scalarType), tmpAlloc(tmpMemRefType), tmpDealloc(Type{}), vectorView(vectorMemRefType);
|
||||
auto block = edsc::block({
|
||||
tmpAlloc = alloc(tmpMemRefType),
|
||||
vectorView = vector_type_cast(tmpAlloc, vectorMemRefType),
|
||||
store(vectorValue, vectorView, MutableArrayRef<Expr>{zero}),
|
||||
For(ivs, lbs, ubs, steps, {
|
||||
scalarValue = load(tmpAlloc, accessInfo.tmpAccessExprs),
|
||||
store(scalarValue, scalarMemRef, accessInfo.clippedScalarAccessExprs),
|
||||
}),
|
||||
tmpDealloc = dealloc(tmpAlloc)});
|
||||
// clang-format on
|
||||
// 2. Emit alloc-store-copy-dealloc.
|
||||
ValueHandle tmp = alloc(tmpMemRefType());
|
||||
IndexedValue local(tmp);
|
||||
ValueHandle vec = vector_type_cast(tmp, vectorMemRefType());
|
||||
STORE(vectorValue, vec, {index_t(0)});
|
||||
LoopNestBuilder(pivs, lbs, ubs, steps)({
|
||||
// Computes clippedScalarAccessExprs in the loop nest scope (ivs exist).
|
||||
remote(clip(transfer, view, ivs)) = local(ivs),
|
||||
});
|
||||
(dealloc(tmp)); // vexing parse...
|
||||
|
||||
// Emit the MLIR.
|
||||
emitter.emitStmts(block.getBody());
|
||||
|
||||
// Finalize rewriting.
|
||||
transfer->erase();
|
||||
}
|
||||
|
||||
@@ -430,9 +384,6 @@ struct LowerVectorTransfersPass
|
||||
applyMLPatternsGreedily<VectorTransferExpander<VectorTransferReadOp>,
|
||||
VectorTransferExpander<VectorTransferWriteOp>>(f);
|
||||
}
|
||||
|
||||
// Thread-safe RAII context with local scope. BumpPtrAllocator freed on exit.
|
||||
edsc::ScopedEDSCContext raiiContext;
|
||||
};
|
||||
|
||||
} // end anonymous namespace
|
||||
|
||||
@@ -338,10 +338,10 @@ TEST_FUNC(builder_helpers) {
|
||||
step2 = vA.step(2);
|
||||
LoopNestBuilder({&i, &j}, {lb0, lb1}, {ub0, ub1}, {step0, step1})({
|
||||
LoopBuilder(&k1, lb2, ub2, step2)({
|
||||
C({i, j, k1}) = f7 + A({i, j, k1}) + B({i, j, k1}),
|
||||
C(i, j, k1) = f7 + A(i, j, k1) + B(i, j, k1),
|
||||
}),
|
||||
LoopBuilder(&k2, lb2, ub2, step2)({
|
||||
C({i, j, k2}) += A({i, j, k2}) + B({i, j, k2}),
|
||||
C(i, j, k2) += A(i, j, k2) + B(i, j, k2),
|
||||
}),
|
||||
});
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
// RUN: mlir-opt %s -lower-vector-transfers | FileCheck %s
|
||||
|
||||
// CHECK: #[[ADD:map[0-9]+]] = (d0, d1) -> (d0 + d1)
|
||||
// CHECK: #[[SUB:map[0-9]+]] = (d0, d1) -> (d0 - d1)
|
||||
// CHECK: #[[SUB:map[0-9]+]] = (d0) -> (d0 - 1)
|
||||
|
||||
// CHECK-LABEL: func @materialize_read_1d() {
|
||||
func @materialize_read_1d() {
|
||||
@@ -58,49 +58,53 @@ func @materialize_read(%M: index, %N: index, %O: index, %P: index) {
|
||||
// CHECK-NEXT: for %[[I1:.*]] = 0 to %arg1 {
|
||||
// CHECK-NEXT: for %[[I2:.*]] = 0 to %arg2 {
|
||||
// CHECK-NEXT: for %[[I3:.*]] = 0 to %arg3 step 5 {
|
||||
// CHECK-NEXT: %[[C0:.*]] = constant 0 : index
|
||||
// CHECK-NEXT: %[[C1:.*]] = constant 1 : index
|
||||
// CHECK: {{.*}} = dim %0, 0 : memref<?x?x?x?xf32>
|
||||
// CHECK-NEXT: {{.*}} = dim %0, 1 : memref<?x?x?x?xf32>
|
||||
// CHECK-NEXT: {{.*}} = dim %0, 2 : memref<?x?x?x?xf32>
|
||||
// CHECK-NEXT: {{.*}} = dim %0, 3 : memref<?x?x?x?xf32>
|
||||
// CHECK: %[[D0:.*]] = dim %0, 0 : memref<?x?x?x?xf32>
|
||||
// CHECK-NEXT: %[[D1:.*]] = dim %0, 1 : memref<?x?x?x?xf32>
|
||||
// CHECK-NEXT: %[[D2:.*]] = dim %0, 2 : memref<?x?x?x?xf32>
|
||||
// CHECK-NEXT: %[[D3:.*]] = dim %0, 3 : memref<?x?x?x?xf32>
|
||||
// CHECK: %[[ALLOC:.*]] = alloc() : memref<5x4x3xf32>
|
||||
// CHECK-NEXT: %[[VECTOR_VIEW:.*]] = vector_type_cast %[[ALLOC]] : memref<5x4x3xf32>, memref<1xvector<5x4x3xf32>>
|
||||
// CHECK-NEXT: for %[[I4:.*]] = 0 to 3 {
|
||||
// CHECK-NEXT: for %[[I5:.*]] = 0 to 4 {
|
||||
// CHECK-NEXT: for %[[I6:.*]] = 0 to 5 {
|
||||
// CHECK-NEXT: {{.*}} = affine.apply #[[ADD]]
|
||||
// CHECK-NEXT: %[[C0:.*]] = constant 0 : index
|
||||
// CHECK-NEXT: %[[C1:.*]] = constant 1 : index
|
||||
// CHECK-NEXT: {{.*}} = affine.apply #[[ADD]](%[[I0]], %[[I4]])
|
||||
// CHECK-NEXT: {{.*}} = cmpi "slt", {{.*}}, %[[C0]] : index
|
||||
// CHECK-NEXT: {{.*}} = affine.apply #[[ADD]]
|
||||
// CHECK-NEXT: {{.*}} = affine.apply #[[ADD]](%[[I0]], %[[I4]])
|
||||
// CHECK-NEXT: {{.*}} = cmpi "slt", {{.*}} : index
|
||||
// CHECK-NEXT: {{.*}} = affine.apply #[[ADD]]
|
||||
// CHECK-NEXT: {{.*}} = affine.apply #[[SUB]]
|
||||
// CHECK-NEXT: {{.*}} = select
|
||||
// CHECK-NEXT: {{.*}} = affine.apply #[[ADD]](%[[I0]], %[[I4]])
|
||||
// CHECK-NEXT: {{.*}} = affine.apply #[[SUB]](%[[D0]])
|
||||
// CHECK-NEXT: {{.*}} = select
|
||||
// CHECK-NEXT: %[[L0:.*]] = select
|
||||
//
|
||||
// CHECK-NEXT: {{.*}} = cmpi "slt", {{.*}} : index
|
||||
// CHECK-NEXT: {{.*}} = cmpi "slt", {{.*}} : index
|
||||
// CHECK-NEXT: {{.*}} = affine.apply #[[SUB]]
|
||||
// CHECK-NEXT: {{.*}} = select
|
||||
// CHECK-NEXT: {{.*}} = affine.apply #[[SUB]](%[[D1]])
|
||||
// CHECK-NEXT: {{.*}} = select
|
||||
// CHECK-NEXT: %[[L1:.*]] = select
|
||||
//
|
||||
// CHECK-NEXT: {{.*}} = cmpi "slt", {{.*}} : index
|
||||
// CHECK-NEXT: {{.*}} = cmpi "slt", {{.*}} : index
|
||||
// CHECK-NEXT: {{.*}} = affine.apply #[[SUB]]
|
||||
// CHECK-NEXT: {{.*}} = affine.apply #[[SUB]](%[[D2]])
|
||||
// CHECK-NEXT: {{.*}} = select
|
||||
// CHECK-NEXT: {{.*}} = select
|
||||
// CHECK-NEXT: {{.*}} = affine.apply #[[ADD]]
|
||||
// CHECK-NEXT: %[[L2:.*]] = select
|
||||
//
|
||||
// CHECK-NEXT: {{.*}} = affine.apply #[[ADD]](%[[I3]], %[[I6]])
|
||||
// CHECK-NEXT: {{.*}} = cmpi "slt", {{.*}}, %[[C0]] : index
|
||||
// CHECK-NEXT: {{.*}} = affine.apply #[[ADD]]
|
||||
// CHECK-NEXT: {{.*}} = cmpi "slt", {{.*}}, {{.*}} : index
|
||||
// CHECK-NEXT: {{.*}} = affine.apply #[[ADD]]
|
||||
// CHECK-NEXT: {{.*}} = affine.apply #[[SUB]]
|
||||
// CHECK-NEXT: {{.*}} = select {{.*}} : index
|
||||
// CHECK-NEXT: {{.*}} = select {{.*}} : index
|
||||
// CHECK-NEXT: {{.*}} = load %0[{{.*}}] : memref<?x?x?x?xf32>
|
||||
// CHECK-NEXT: store {{.*}}, %[[ALLOC]][{{.*}}] : memref<5x4x3xf32>
|
||||
// CHECK-NEXT: {{.*}} = affine.apply #[[ADD]](%[[I3]], %[[I6]])
|
||||
// CHECK-NEXT: {{.*}} = cmpi "slt", {{.*}} : index
|
||||
// CHECK-NEXT: {{.*}} = affine.apply #[[ADD]](%[[I3]], %[[I6]])
|
||||
// CHECK-NEXT: {{.*}} = affine.apply #[[SUB]](%[[D3]])
|
||||
// CHECK-NEXT: {{.*}} = select
|
||||
// CHECK-NEXT: %[[L3:.*]] = select
|
||||
//
|
||||
// CHECK-NEXT: {{.*}} = load %0[%[[L0]], %[[L1]], %[[L2]], %[[L3]]] : memref<?x?x?x?xf32>
|
||||
// CHECK-NEXT: store {{.*}}, %[[ALLOC]][%[[I6]], %[[I5]], %[[I4]]] : memref<5x4x3xf32>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: {{.*}} = load %[[VECTOR_VIEW]][%[[C0]]] : memref<1xvector<5x4x3xf32>>
|
||||
// CHECK: {{.*}} = load %[[VECTOR_VIEW]][{{.*}}] : memref<1xvector<5x4x3xf32>>
|
||||
// CHECK-NEXT: dealloc %[[ALLOC]] : memref<5x4x3xf32>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }
|
||||
@@ -108,6 +112,10 @@ func @materialize_read(%M: index, %N: index, %O: index, %P: index) {
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: return
|
||||
// CHECK-NEXT:}
|
||||
|
||||
// Check that I0 + I4 (of size 3) read from first index load(L0, ...) and write into last index store(..., I4)
|
||||
// Check that I3 + I6 (of size 5) read from last index load(..., L3) and write into first index store(I6, ...)
|
||||
// Other dimensions are just accessed with I1, I2 resp.
|
||||
%A = alloc (%M, %N, %O, %P) : memref<?x?x?x?xf32, 0>
|
||||
for %i0 = 0 to %M step 3 {
|
||||
for %i1 = 0 to %N {
|
||||
@@ -129,49 +137,53 @@ func @materialize_write(%M: index, %N: index, %O: index, %P: index) {
|
||||
// CHECK-NEXT: for %[[I1:.*]] = 0 to %arg1 step 4 {
|
||||
// CHECK-NEXT: for %[[I2:.*]] = 0 to %arg2 {
|
||||
// CHECK-NEXT: for %[[I3:.*]] = 0 to %arg3 step 5 {
|
||||
// CHECK-NEXT: %[[C0:.*]] = constant 0 : index
|
||||
// CHECK-NEXT: %[[C1:.*]] = constant 1 : index
|
||||
// CHECK: {{.*}} = dim %0, 0 : memref<?x?x?x?xf32>
|
||||
// CHECK-NEXT: {{.*}} = dim %0, 1 : memref<?x?x?x?xf32>
|
||||
// CHECK-NEXT: {{.*}} = dim %0, 2 : memref<?x?x?x?xf32>
|
||||
// CHECK-NEXT: {{.*}} = dim %0, 3 : memref<?x?x?x?xf32>
|
||||
// CHECK: %[[D0:.*]] = dim %0, 0 : memref<?x?x?x?xf32>
|
||||
// CHECK-NEXT: %[[D1:.*]] = dim %0, 1 : memref<?x?x?x?xf32>
|
||||
// CHECK-NEXT: %[[D2:.*]] = dim %0, 2 : memref<?x?x?x?xf32>
|
||||
// CHECK-NEXT: %[[D3:.*]] = dim %0, 3 : memref<?x?x?x?xf32>
|
||||
// CHECK: %[[ALLOC:.*]] = alloc() : memref<5x4x3xf32>
|
||||
// CHECK-NEXT: %[[VECTOR_VIEW:.*]] = vector_type_cast {{.*}} : memref<5x4x3xf32>, memref<1xvector<5x4x3xf32>>
|
||||
// CHECK-NEXT: store %cst, {{.*}}[%[[C0]]] : memref<1xvector<5x4x3xf32>>
|
||||
// CHECK: store %cst, {{.*}} : memref<1xvector<5x4x3xf32>>
|
||||
// CHECK-NEXT: for %[[I4:.*]] = 0 to 3 {
|
||||
// CHECK-NEXT: for %[[I5:.*]] = 0 to 4 {
|
||||
// CHECK-NEXT: for %[[I6:.*]] = 0 to 5 {
|
||||
// CHECK-NEXT: {{.*}} = load {{.*}}[%[[I6]], %[[I5]], %[[I4]]] : memref<5x4x3xf32>
|
||||
// CHECK-NEXT: %[[C0:.*]] = constant 0 : index
|
||||
// CHECK-NEXT: %[[C1:.*]] = constant 1 : index
|
||||
// CHECK-NEXT: {{.*}} = affine.apply #[[ADD]](%[[I0]], %[[I4]])
|
||||
// CHECK-NEXT: {{.*}} = cmpi "slt", {{.*}}, %[[C0]] : index
|
||||
// CHECK-NEXT: {{.*}} = affine.apply #[[ADD]](%[[I0]], %[[I4]])
|
||||
// CHECK-NEXT: {{.*}} = cmpi "slt", {{.*}}, {{.*}} : index
|
||||
// CHECK-NEXT: {{.*}} = affine.apply #[[ADD]](%[[I0]], %[[I4]])
|
||||
// CHECK-NEXT: {{.*}} = affine.apply #[[SUB]]({{.*}}, %[[C1]])
|
||||
// CHECK-NEXT: {{.*}} = affine.apply #[[SUB]](%[[D0]])
|
||||
// CHECK-NEXT: {{.*}} = select {{.*}}, {{.*}}, {{.*}} : index
|
||||
// CHECK-NEXT: {{.*}} = select {{.*}}, %[[C0]], {{.*}} : index
|
||||
// CHECK-NEXT: %[[S0:.*]] = select {{.*}}, %[[C0]], {{.*}} : index
|
||||
//
|
||||
// CHECK-NEXT: {{.*}} = affine.apply #[[ADD]](%[[I1]], %[[I5]])
|
||||
// CHECK-NEXT: {{.*}} = cmpi "slt", {{.*}}, %[[C0]] : index
|
||||
// CHECK-NEXT: {{.*}} = affine.apply #[[ADD]](%[[I1]], %[[I5]])
|
||||
// CHECK-NEXT: {{.*}} = cmpi "slt", {{.*}}, {{.*}} : index
|
||||
// CHECK-NEXT: {{.*}} = affine.apply #[[ADD]](%[[I1]], %[[I5]])
|
||||
// CHECK-NEXT: {{.*}} = affine.apply #[[SUB]]({{.*}}, %[[C1]])
|
||||
// CHECK-NEXT: {{.*}} = affine.apply #[[SUB]](%[[D1]])
|
||||
// CHECK-NEXT: {{.*}} = select {{.*}}, {{.*}}, {{.*}} : index
|
||||
// CHECK-NEXT: {{.*}} = select {{.*}}, %[[C0]], {{.*}} : index
|
||||
// CHECK-NEXT: %[[S1:.*]] = select {{.*}}, %[[C0]], {{.*}} : index
|
||||
//
|
||||
// CHECK-NEXT: {{.*}} = cmpi "slt", %[[I2]], %[[C0]] : index
|
||||
// CHECK-NEXT: {{.*}} = cmpi "slt", %[[I2]], %3 : index
|
||||
// CHECK-NEXT: {{.*}} = affine.apply #map{{.*}}(%3, %[[C1]])
|
||||
// CHECK-NEXT: {{.*}} = affine.apply #[[SUB]](%[[D2]])
|
||||
// CHECK-NEXT: {{.*}} = select {{.*}}, %[[I2]], {{.*}} : index
|
||||
// CHECK-NEXT: {{.*}} = select {{.*}}, %[[C0]], {{.*}} : index
|
||||
// CHECK-NEXT: %[[S2:.*]] = select {{.*}}, %[[C0]], {{.*}} : index
|
||||
//
|
||||
// CHECK-NEXT: {{.*}} = affine.apply #[[ADD]](%[[I3]], %[[I6]])
|
||||
// CHECK-NEXT: {{.*}} = cmpi "slt", {{.*}}, %[[C0]] : index
|
||||
// CHECK-NEXT: {{.*}} = affine.apply #[[ADD]](%[[I3]], %[[I6]])
|
||||
// CHECK-NEXT: {{.*}} = cmpi "slt", {{.*}}, {{.*}} : index
|
||||
// CHECK-NEXT: {{.*}} = affine.apply #[[ADD]](%[[I3]], %[[I6]])
|
||||
// CHECK-NEXT: {{.*}} = affine.apply #[[SUB]]({{.*}}, %[[C1]])
|
||||
// CHECK-NEXT: {{.*}} = affine.apply #[[SUB]](%[[D3]])
|
||||
// CHECK-NEXT: {{.*}} = select {{.*}}, {{.*}}, {{.*}} : index
|
||||
// CHECK-NEXT: {{.*}} = select {{.*}}, %[[C0]], {{.*}} : index
|
||||
// CHECK-NEXT: store {{.*}}, {{.*}}[{{.*}}, {{.*}}, {{.*}}, {{.*}}] : memref<?x?x?x?xf32>
|
||||
// CHECK-NEXT: %[[S3:.*]] = select {{.*}}, %[[C0]], {{.*}} : index
|
||||
//
|
||||
// CHECK-NEXT: {{.*}} = load {{.*}}[%[[I6]], %[[I5]], %[[I4]]] : memref<5x4x3xf32>
|
||||
// CHECK: store {{.*}}, {{.*}}[%[[S0]], %[[S1]], %[[S2]], %[[S3]]] : memref<?x?x?x?xf32>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }
|
||||
@@ -182,6 +194,11 @@ func @materialize_write(%M: index, %N: index, %O: index, %P: index) {
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: return
|
||||
// CHECK-NEXT:}
|
||||
//
|
||||
// Check that I0 + I4 (of size 3) read from last index load(..., I4) and write into first index store(S0, ...)
|
||||
// Check that I1 + I5 (of size 4) read from second index load(..., I5, ...) and write into second index store(..., S1, ...)
|
||||
// Check that I3 + I6 (of size 5) read from first index load(I6, ...) and write into last index store(..., S3)
|
||||
// Other dimension is just accessed with I2.
|
||||
%A = alloc (%M, %N, %O, %P) : memref<?x?x?x?xf32, 0>
|
||||
%f1 = constant splat<vector<5x4x3xf32>, 1.000000e+00> : vector<5x4x3xf32>
|
||||
for %i0 = 0 to %M step 3 {
|
||||
|
||||
Reference in New Issue
Block a user