mirror of
https://github.com/intel/llvm.git
synced 2026-01-24 08:30:34 +08:00
[mlir][Vector] Add vector contraction to outerproduct lowering
This revision adds the additional lowering and exposes the patterns at a finer granularity for better programmatic reuse. The unit test makes use of the finer grained pattern for simpler checks. As the ContractionOpLowering is exposed programmatically, cleanup opportunities appear and static class methods are turned into free functions with static visibility. Differential Revision: https://reviews.llvm.org/D80375
This commit is contained in:
@@ -39,6 +39,120 @@
|
||||
using namespace mlir;
|
||||
using llvm::dbgs;
|
||||
|
||||
// Helper to find an index in an affine map.
|
||||
static Optional<int64_t> getResultIndex(AffineMap map, int64_t index) {
|
||||
for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
|
||||
int64_t idx = map.getResult(i).cast<AffineDimExpr>().getPosition();
|
||||
if (idx == index)
|
||||
return i;
|
||||
}
|
||||
return None;
|
||||
}
|
||||
|
||||
// Helper to construct iterator types with one index removed.
|
||||
static SmallVector<Attribute, 4> adjustIter(ArrayAttr iteratorTypes,
|
||||
int64_t index) {
|
||||
SmallVector<Attribute, 4> results;
|
||||
for (auto it : llvm::enumerate(iteratorTypes)) {
|
||||
int64_t idx = it.index();
|
||||
if (idx == index)
|
||||
continue;
|
||||
results.push_back(it.value());
|
||||
}
|
||||
return results;
|
||||
}
|
||||
|
||||
// Helper to construct an affine map with one index removed.
|
||||
static AffineMap adjustMap(AffineMap map, int64_t index,
|
||||
PatternRewriter &rewriter) {
|
||||
auto *ctx = rewriter.getContext();
|
||||
SmallVector<AffineExpr, 4> results;
|
||||
for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
|
||||
int64_t idx = map.getResult(i).cast<AffineDimExpr>().getPosition();
|
||||
if (idx == index)
|
||||
continue;
|
||||
// Re-insert remaining indices, but renamed when occurring
|
||||
// after the removed index.
|
||||
auto targetExpr = getAffineDimExpr(idx < index ? idx : idx - 1, ctx);
|
||||
results.push_back(targetExpr);
|
||||
}
|
||||
return AffineMap::get(map.getNumDims() - 1, 0, results, ctx);
|
||||
}
|
||||
|
||||
// Helper to drop dimension from vector type.
|
||||
static Type adjustType(VectorType tp, int64_t index) {
|
||||
int64_t rank = tp.getRank();
|
||||
Type eltType = tp.getElementType();
|
||||
if (rank == 1) {
|
||||
assert(index == 0 && "index for scalar result out of bounds");
|
||||
return eltType;
|
||||
}
|
||||
SmallVector<int64_t, 4> adjustedShape;
|
||||
for (int64_t i = 0; i < rank; ++i) {
|
||||
// Omit dimension at the given index.
|
||||
if (i == index)
|
||||
continue;
|
||||
// Otherwise, add dimension back.
|
||||
adjustedShape.push_back(tp.getDimSize(i));
|
||||
}
|
||||
return VectorType::get(adjustedShape, eltType);
|
||||
}
|
||||
|
||||
// Helper method to possibly drop a dimension in a load.
|
||||
// TODO(ajcbik): use a reshaping vector load (and share lowering code)
|
||||
static Value reshapeLoad(Location loc, Value val, VectorType type,
|
||||
int64_t index, int64_t pos,
|
||||
PatternRewriter &rewriter) {
|
||||
if (index == -1)
|
||||
return val;
|
||||
Type lowType = adjustType(type, 0);
|
||||
// At extraction dimension?
|
||||
if (index == 0) {
|
||||
auto posAttr = rewriter.getI64ArrayAttr(pos);
|
||||
return rewriter.create<vector::ExtractOp>(loc, lowType, val, posAttr);
|
||||
}
|
||||
// Unroll leading dimensions.
|
||||
VectorType vType = lowType.cast<VectorType>();
|
||||
VectorType resType = adjustType(type, index).cast<VectorType>();
|
||||
Value result =
|
||||
rewriter.create<ConstantOp>(loc, resType, rewriter.getZeroAttr(resType));
|
||||
for (int64_t d = 0, e = resType.getDimSize(0); d < e; d++) {
|
||||
auto posAttr = rewriter.getI64ArrayAttr(d);
|
||||
Value ext = rewriter.create<vector::ExtractOp>(loc, vType, val, posAttr);
|
||||
Value load = reshapeLoad(loc, ext, vType, index - 1, pos, rewriter);
|
||||
result =
|
||||
rewriter.create<vector::InsertOp>(loc, resType, load, result, posAttr);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
// Helper method to possibly drop a dimension in a store.
|
||||
// TODO(ajcbik): use a reshaping vector store (and share lowering code)
|
||||
static Value reshapeStore(Location loc, Value val, Value result,
|
||||
VectorType type, int64_t index, int64_t pos,
|
||||
PatternRewriter &rewriter) {
|
||||
// Unmodified?
|
||||
if (index == -1)
|
||||
return val;
|
||||
// At insertion dimension?
|
||||
if (index == 0) {
|
||||
auto posAttr = rewriter.getI64ArrayAttr(pos);
|
||||
return rewriter.create<vector::InsertOp>(loc, type, val, result, posAttr);
|
||||
}
|
||||
// Unroll leading dimensions.
|
||||
Type lowType = adjustType(type, 0);
|
||||
VectorType vType = lowType.cast<VectorType>();
|
||||
Type insType = adjustType(vType, 0);
|
||||
for (int64_t d = 0, e = type.getDimSize(0); d < e; d++) {
|
||||
auto posAttr = rewriter.getI64ArrayAttr(d);
|
||||
Value ext = rewriter.create<vector::ExtractOp>(loc, vType, result, posAttr);
|
||||
Value ins = rewriter.create<vector::ExtractOp>(loc, insType, val, posAttr);
|
||||
Value sto = reshapeStore(loc, ins, ext, vType, index - 1, pos, rewriter);
|
||||
result = rewriter.create<vector::InsertOp>(loc, type, sto, result, posAttr);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
// Clones `op` into a new operations that takes `operands` and returns
|
||||
// `resultTypes`.
|
||||
static Operation *cloneOpWithOperandsAndTypes(OpBuilder &builder, Location loc,
|
||||
@@ -1252,343 +1366,6 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
/// Progressive lowering of ContractionOp.
|
||||
/// One:
|
||||
/// %x = vector.contract with at least one free/batch dimension
|
||||
/// is replaced by:
|
||||
/// %a = vector.contract with one less free/batch dimension
|
||||
/// %b = vector.contract with one less free/batch dimension
|
||||
/// ..
|
||||
/// %x = combine %a %b ..
|
||||
/// until a pure contraction is reached (no free/batch dimensions),
|
||||
/// which is replaced by a fma/reduction op.
|
||||
///
|
||||
/// TODO(ajcbik): break down into transpose/reshape/cast ops
|
||||
/// when they become available to avoid code dup
|
||||
/// TODO(ajcbik): investigate lowering order impact on performance
|
||||
class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
|
||||
public:
|
||||
using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
|
||||
|
||||
ContractionOpLowering(vector::VectorTransformsOptions vectorTransformsOptions,
|
||||
MLIRContext *context)
|
||||
: OpRewritePattern<vector::ContractionOp>(context),
|
||||
vectorTransformsOptions(vectorTransformsOptions) {}
|
||||
|
||||
LogicalResult matchAndRewrite(vector::ContractionOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// TODO(ajcbik): implement masks
|
||||
if (llvm::size(op.masks()) != 0)
|
||||
return failure();
|
||||
|
||||
// TODO(ntv, ajcbik): implement benefits, cost models, separate this out in
|
||||
// a new pattern.
|
||||
if (vectorTransformsOptions.lowerToLLVMMatrixIntrinsics &&
|
||||
isRowMajorMatmul(op.indexing_maps())) {
|
||||
VectorType lhsType = op.getLhsType();
|
||||
VectorType rhsType = op.getRhsType();
|
||||
unsigned lhsRows = op.getLhsType().getShape()[0];
|
||||
unsigned lhsColumns = op.getLhsType().getShape()[1];
|
||||
unsigned rhsColumns = op.getRhsType().getShape()[1];
|
||||
|
||||
Type flattenedLHSType =
|
||||
VectorType::get(lhsType.getNumElements(), lhsType.getElementType());
|
||||
Type flattenedRHSType =
|
||||
VectorType::get(rhsType.getNumElements(), rhsType.getElementType());
|
||||
auto lhs = rewriter.create<vector::ShapeCastOp>(
|
||||
op.getLoc(), flattenedLHSType, op.lhs());
|
||||
auto rhs = rewriter.create<vector::ShapeCastOp>(
|
||||
op.getLoc(), flattenedRHSType, op.rhs());
|
||||
|
||||
Value mul = rewriter.create<vector::MatmulOp>(
|
||||
op.getLoc(), lhs, rhs, lhsRows, lhsColumns, rhsColumns);
|
||||
mul = rewriter.create<vector::ShapeCastOp>(op.getLoc(),
|
||||
op.acc().getType(), mul);
|
||||
Type elementType = op.getLhsType().getElementType();
|
||||
assert(elementType.isIntOrFloat());
|
||||
if (elementType.isa<IntegerType>())
|
||||
rewriter.replaceOpWithNewOp<AddIOp>(op, op.acc(), mul);
|
||||
else
|
||||
rewriter.replaceOpWithNewOp<AddFOp>(op, op.acc(), mul);
|
||||
return success();
|
||||
}
|
||||
|
||||
// Find first batch dimension in LHS/RHS, and lower when found.
|
||||
std::vector<std::pair<int64_t, int64_t>> batchDimMap = op.getBatchDimMap();
|
||||
if (!batchDimMap.empty()) {
|
||||
int64_t lhsIndex = batchDimMap[0].first;
|
||||
int64_t rhsIndex = batchDimMap[0].second;
|
||||
rewriter.replaceOp(op, lowerParallel(op, lhsIndex, rhsIndex, rewriter));
|
||||
return success();
|
||||
}
|
||||
|
||||
// Collect contracting dimensions.
|
||||
std::vector<std::pair<int64_t, int64_t>> contractingDimMap =
|
||||
op.getContractingDimMap();
|
||||
DenseSet<int64_t> lhsContractingDimSet;
|
||||
DenseSet<int64_t> rhsContractingDimSet;
|
||||
for (auto &dimPair : contractingDimMap) {
|
||||
lhsContractingDimSet.insert(dimPair.first);
|
||||
rhsContractingDimSet.insert(dimPair.second);
|
||||
}
|
||||
|
||||
// Find first free dimension in LHS, and lower when found.
|
||||
VectorType lhsType = op.getLhsType();
|
||||
for (int64_t lhsIndex = 0, e = lhsType.getRank(); lhsIndex < e;
|
||||
++lhsIndex) {
|
||||
if (lhsContractingDimSet.count(lhsIndex) == 0) {
|
||||
rewriter.replaceOp(
|
||||
op, lowerParallel(op, lhsIndex, /*rhsIndex=*/-1, rewriter));
|
||||
return success();
|
||||
}
|
||||
}
|
||||
|
||||
// Find first free dimension in RHS, and lower when found.
|
||||
VectorType rhsType = op.getRhsType();
|
||||
for (int64_t rhsIndex = 0, e = rhsType.getRank(); rhsIndex < e;
|
||||
++rhsIndex) {
|
||||
if (rhsContractingDimSet.count(rhsIndex) == 0) {
|
||||
rewriter.replaceOp(
|
||||
op, lowerParallel(op, /*lhsIndex=*/-1, rhsIndex, rewriter));
|
||||
return success();
|
||||
}
|
||||
}
|
||||
|
||||
// Lower the first remaining reduction dimension.
|
||||
if (!contractingDimMap.empty()) {
|
||||
rewriter.replaceOp(op, lowerReduction(op, rewriter));
|
||||
return success();
|
||||
}
|
||||
|
||||
return failure();
|
||||
}
|
||||
|
||||
private:
|
||||
// Lower one parallel dimension.
|
||||
// TODO(ajcbik): consider reusing existing contract unrolling
|
||||
Value lowerParallel(vector::ContractionOp op, int64_t lhsIndex,
|
||||
int64_t rhsIndex, PatternRewriter &rewriter) const {
|
||||
VectorType lhsType = op.getLhsType();
|
||||
VectorType rhsType = op.getRhsType();
|
||||
VectorType resType = op.getResultType().cast<VectorType>();
|
||||
// Find the iterator type index and result index.
|
||||
SmallVector<AffineMap, 4> iMap = op.getIndexingMaps();
|
||||
int64_t iterIndex = -1;
|
||||
int64_t dimSize = -1;
|
||||
if (lhsIndex >= 0) {
|
||||
iterIndex =
|
||||
iMap[0].getResult(lhsIndex).cast<AffineDimExpr>().getPosition();
|
||||
assert((rhsIndex < 0 || iterIndex == iMap[1]
|
||||
.getResult(rhsIndex)
|
||||
.cast<AffineDimExpr>()
|
||||
.getPosition()) &&
|
||||
"parallel index should be free in LHS or batch in LHS/RHS");
|
||||
dimSize = lhsType.getDimSize(lhsIndex);
|
||||
} else {
|
||||
assert(rhsIndex >= 0 && "missing parallel index");
|
||||
iterIndex =
|
||||
iMap[1].getResult(rhsIndex).cast<AffineDimExpr>().getPosition();
|
||||
dimSize = rhsType.getDimSize(rhsIndex);
|
||||
}
|
||||
assert(iterIndex >= 0 && "parallel index not listed in operand mapping");
|
||||
Optional<int64_t> lookup = getResultIndex(iMap[2], iterIndex);
|
||||
assert(lookup.hasValue() && "parallel index not listed in reduction");
|
||||
int64_t resIndex = lookup.getValue();
|
||||
// Construct new iterator types and affine map array attribute.
|
||||
SmallVector<AffineMap, 4> lowIndexingMaps;
|
||||
lowIndexingMaps.push_back(adjustMap(iMap[0], iterIndex, rewriter));
|
||||
lowIndexingMaps.push_back(adjustMap(iMap[1], iterIndex, rewriter));
|
||||
lowIndexingMaps.push_back(adjustMap(iMap[2], iterIndex, rewriter));
|
||||
auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
|
||||
auto lowIter =
|
||||
rewriter.getArrayAttr(adjustIter(op.iterator_types(), iterIndex));
|
||||
// Unroll into a series of lower dimensional vector.contract ops.
|
||||
Location loc = op.getLoc();
|
||||
Value result = rewriter.create<ConstantOp>(loc, resType,
|
||||
rewriter.getZeroAttr(resType));
|
||||
for (int64_t d = 0; d < dimSize; ++d) {
|
||||
auto lhs = reshapeLoad(loc, op.lhs(), lhsType, lhsIndex, d, rewriter);
|
||||
auto rhs = reshapeLoad(loc, op.rhs(), rhsType, rhsIndex, d, rewriter);
|
||||
auto acc = reshapeLoad(loc, op.acc(), resType, resIndex, d, rewriter);
|
||||
Value lowContract = rewriter.create<vector::ContractionOp>(
|
||||
loc, lhs, rhs, acc, lowAffine, lowIter);
|
||||
result = reshapeStore(loc, lowContract, result, resType, resIndex, d,
|
||||
rewriter);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
// Lower one reduction dimension.
|
||||
Value lowerReduction(vector::ContractionOp op,
|
||||
PatternRewriter &rewriter) const {
|
||||
auto loc = op.getLoc();
|
||||
VectorType lhsType = op.getLhsType();
|
||||
VectorType rhsType = op.getRhsType();
|
||||
Type resType = op.getResultType();
|
||||
assert(!resType.isa<VectorType>());
|
||||
// Use iterator index 0.
|
||||
int64_t iterIndex = 0;
|
||||
SmallVector<AffineMap, 4> iMap = op.getIndexingMaps();
|
||||
Optional<int64_t> lookupLhs = getResultIndex(iMap[0], iterIndex);
|
||||
Optional<int64_t> lookupRhs = getResultIndex(iMap[1], iterIndex);
|
||||
assert(lookupLhs.hasValue() && "missing LHS parallel index");
|
||||
assert(lookupRhs.hasValue() && "missing RHS parallel index");
|
||||
int64_t lhsIndex = lookupLhs.getValue();
|
||||
int64_t rhsIndex = lookupRhs.getValue();
|
||||
int64_t dimSize = lhsType.getDimSize(lhsIndex);
|
||||
assert(dimSize == rhsType.getDimSize(rhsIndex) && "corrupt shape");
|
||||
// Base case.
|
||||
if (lhsType.getRank() == 1) {
|
||||
assert(rhsType.getRank() == 1 && "corrupt contraction");
|
||||
Value zero = rewriter.create<ConstantOp>(loc, lhsType,
|
||||
rewriter.getZeroAttr(lhsType));
|
||||
Value fma = rewriter.create<vector::FMAOp>(loc, op.lhs(), op.rhs(), zero);
|
||||
StringAttr kind = rewriter.getStringAttr("add");
|
||||
return rewriter.create<vector::ReductionOp>(loc, resType, kind, fma,
|
||||
op.acc());
|
||||
}
|
||||
// Construct new iterator types and affine map array attribute.
|
||||
SmallVector<AffineMap, 4> lowIndexingMaps;
|
||||
lowIndexingMaps.push_back(adjustMap(iMap[0], iterIndex, rewriter));
|
||||
lowIndexingMaps.push_back(adjustMap(iMap[1], iterIndex, rewriter));
|
||||
lowIndexingMaps.push_back(adjustMap(iMap[2], iterIndex, rewriter));
|
||||
auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
|
||||
auto lowIter =
|
||||
rewriter.getArrayAttr(adjustIter(op.iterator_types(), iterIndex));
|
||||
// Unroll into a series of lower dimensional vector.contract ops.
|
||||
// By feeding the initial accumulator into the first contraction,
|
||||
// and the result of each contraction into the next, eventually
|
||||
// the sum of all reductions is computed.
|
||||
Value result = op.acc();
|
||||
for (int64_t d = 0; d < dimSize; ++d) {
|
||||
auto lhs = reshapeLoad(loc, op.lhs(), lhsType, lhsIndex, d, rewriter);
|
||||
auto rhs = reshapeLoad(loc, op.rhs(), rhsType, rhsIndex, d, rewriter);
|
||||
result = rewriter.create<vector::ContractionOp>(loc, lhs, rhs, result,
|
||||
lowAffine, lowIter);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
// Helper to find an index in an affine map.
|
||||
static Optional<int64_t> getResultIndex(AffineMap map, int64_t index) {
|
||||
for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
|
||||
int64_t idx = map.getResult(i).cast<AffineDimExpr>().getPosition();
|
||||
if (idx == index)
|
||||
return i;
|
||||
}
|
||||
return None;
|
||||
}
|
||||
|
||||
// Helper to construct iterator types with one index removed.
|
||||
static SmallVector<Attribute, 4> adjustIter(ArrayAttr iteratorTypes,
|
||||
int64_t index) {
|
||||
SmallVector<Attribute, 4> results;
|
||||
for (auto it : llvm::enumerate(iteratorTypes)) {
|
||||
int64_t idx = it.index();
|
||||
if (idx == index)
|
||||
continue;
|
||||
results.push_back(it.value());
|
||||
}
|
||||
return results;
|
||||
}
|
||||
|
||||
// Helper to construct an affine map with one index removed.
|
||||
static AffineMap adjustMap(AffineMap map, int64_t index,
|
||||
PatternRewriter &rewriter) {
|
||||
auto *ctx = rewriter.getContext();
|
||||
SmallVector<AffineExpr, 4> results;
|
||||
for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
|
||||
int64_t idx = map.getResult(i).cast<AffineDimExpr>().getPosition();
|
||||
if (idx == index)
|
||||
continue;
|
||||
// Re-insert remaining indices, but renamed when occurring
|
||||
// after the removed index.
|
||||
auto targetExpr = getAffineDimExpr(idx < index ? idx : idx - 1, ctx);
|
||||
results.push_back(targetExpr);
|
||||
}
|
||||
return AffineMap::get(map.getNumDims() - 1, 0, results, ctx);
|
||||
}
|
||||
|
||||
// Helper to drop dimension from vector type.
|
||||
static Type adjustType(VectorType tp, int64_t index) {
|
||||
int64_t rank = tp.getRank();
|
||||
Type eltType = tp.getElementType();
|
||||
if (rank == 1) {
|
||||
assert(index == 0 && "index for scalar result out of bounds");
|
||||
return eltType;
|
||||
}
|
||||
SmallVector<int64_t, 4> adjustedShape;
|
||||
for (int64_t i = 0; i < rank; ++i) {
|
||||
// Omit dimension at the given index.
|
||||
if (i == index)
|
||||
continue;
|
||||
// Otherwise, add dimension back.
|
||||
adjustedShape.push_back(tp.getDimSize(i));
|
||||
}
|
||||
return VectorType::get(adjustedShape, eltType);
|
||||
}
|
||||
|
||||
// Helper method to possibly drop a dimension in a load.
|
||||
// TODO(ajcbik): use a reshaping vector load (and share lowering code)
|
||||
static Value reshapeLoad(Location loc, Value val, VectorType type,
|
||||
int64_t index, int64_t pos,
|
||||
PatternRewriter &rewriter) {
|
||||
if (index == -1)
|
||||
return val;
|
||||
Type lowType = adjustType(type, 0);
|
||||
// At extraction dimension?
|
||||
if (index == 0) {
|
||||
auto posAttr = rewriter.getI64ArrayAttr(pos);
|
||||
return rewriter.create<vector::ExtractOp>(loc, lowType, val, posAttr);
|
||||
}
|
||||
// Unroll leading dimensions.
|
||||
VectorType vType = lowType.cast<VectorType>();
|
||||
VectorType resType = adjustType(type, index).cast<VectorType>();
|
||||
Value result = rewriter.create<ConstantOp>(loc, resType,
|
||||
rewriter.getZeroAttr(resType));
|
||||
for (int64_t d = 0, e = resType.getDimSize(0); d < e; d++) {
|
||||
auto posAttr = rewriter.getI64ArrayAttr(d);
|
||||
Value ext = rewriter.create<vector::ExtractOp>(loc, vType, val, posAttr);
|
||||
Value load = reshapeLoad(loc, ext, vType, index - 1, pos, rewriter);
|
||||
result = rewriter.create<vector::InsertOp>(loc, resType, load, result,
|
||||
posAttr);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
// Helper method to possibly drop a dimension in a store.
|
||||
// TODO(ajcbik): use a reshaping vector store (and share lowering code)
|
||||
static Value reshapeStore(Location loc, Value val, Value result,
|
||||
VectorType type, int64_t index, int64_t pos,
|
||||
PatternRewriter &rewriter) {
|
||||
// Unmodified?
|
||||
if (index == -1)
|
||||
return val;
|
||||
// At insertion dimension?
|
||||
if (index == 0) {
|
||||
auto posAttr = rewriter.getI64ArrayAttr(pos);
|
||||
return rewriter.create<vector::InsertOp>(loc, type, val, result, posAttr);
|
||||
}
|
||||
// Unroll leading dimensions.
|
||||
Type lowType = adjustType(type, 0);
|
||||
VectorType vType = lowType.cast<VectorType>();
|
||||
Type insType = adjustType(vType, 0);
|
||||
for (int64_t d = 0, e = type.getDimSize(0); d < e; d++) {
|
||||
auto posAttr = rewriter.getI64ArrayAttr(d);
|
||||
Value ext =
|
||||
rewriter.create<vector::ExtractOp>(loc, vType, result, posAttr);
|
||||
Value ins =
|
||||
rewriter.create<vector::ExtractOp>(loc, insType, val, posAttr);
|
||||
Value sto = reshapeStore(loc, ins, ext, vType, index - 1, pos, rewriter);
|
||||
result =
|
||||
rewriter.create<vector::InsertOp>(loc, type, sto, result, posAttr);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
vector::VectorTransformsOptions vectorTransformsOptions;
|
||||
};
|
||||
|
||||
/// ShapeOp 2D -> 1D downcast serves the purpose of flattening 2-D to 1-D
|
||||
/// vectors progressively on the way to target llvm.matrix intrinsics.
|
||||
/// This iterates over the most major dimension of the 2-D vector and performs
|
||||
@@ -1656,6 +1433,302 @@ public:
|
||||
|
||||
} // namespace
|
||||
|
||||
namespace mlir {
|
||||
|
||||
/// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
|
||||
/// semantics to:
|
||||
/// ```
|
||||
/// %flattened_a = vector.shape_cast %a
|
||||
/// %flattened_b = vector.shape_cast %b
|
||||
/// %flattened_d = vector.matmul %flattened_a, %flattened_b
|
||||
/// %d = vector.shape_cast %%flattened_d
|
||||
/// %e = add %c, %d
|
||||
/// ```
|
||||
/// `vector.matmul` later lowers to `llvm.matrix.multiply`.
|
||||
//
|
||||
/// This only kicks in when VectorTransformsOptions is set to OuterProduct and
|
||||
/// the vector.contract op is a row-major matrix multiply.
|
||||
LogicalResult
|
||||
ContractionOpToMatmulOpLowering::match(vector::ContractionOp op) const {
|
||||
// TODO(ajcbik): implement masks
|
||||
if (llvm::size(op.masks()) != 0)
|
||||
return failure();
|
||||
|
||||
if (vectorTransformsOptions.vectorContractLowering !=
|
||||
vector::VectorContractLowering::Matmul ||
|
||||
!isRowMajorMatmul(op.indexing_maps()))
|
||||
return failure();
|
||||
return success();
|
||||
}
|
||||
|
||||
void ContractionOpToMatmulOpLowering::rewrite(vector::ContractionOp op,
|
||||
PatternRewriter &rewriter) const {
|
||||
VectorType lhsType = op.getLhsType();
|
||||
VectorType rhsType = op.getRhsType();
|
||||
unsigned lhsRows = op.getLhsType().getShape()[0];
|
||||
unsigned lhsColumns = op.getLhsType().getShape()[1];
|
||||
unsigned rhsColumns = op.getRhsType().getShape()[1];
|
||||
|
||||
Type flattenedLHSType =
|
||||
VectorType::get(lhsType.getNumElements(), lhsType.getElementType());
|
||||
Type flattenedRHSType =
|
||||
VectorType::get(rhsType.getNumElements(), rhsType.getElementType());
|
||||
auto lhs = rewriter.create<vector::ShapeCastOp>(op.getLoc(), flattenedLHSType,
|
||||
op.lhs());
|
||||
auto rhs = rewriter.create<vector::ShapeCastOp>(op.getLoc(), flattenedRHSType,
|
||||
op.rhs());
|
||||
|
||||
Value mul = rewriter.create<vector::MatmulOp>(op.getLoc(), lhs, rhs, lhsRows,
|
||||
lhsColumns, rhsColumns);
|
||||
mul = rewriter.create<vector::ShapeCastOp>(op.getLoc(), op.acc().getType(),
|
||||
mul);
|
||||
Type elementType = op.getLhsType().getElementType();
|
||||
assert(elementType.isIntOrFloat());
|
||||
if (elementType.isa<IntegerType>())
|
||||
rewriter.replaceOpWithNewOp<AddIOp>(op, op.acc(), mul);
|
||||
else
|
||||
rewriter.replaceOpWithNewOp<AddFOp>(op, op.acc(), mul);
|
||||
}
|
||||
|
||||
/// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
|
||||
/// semantics to a reduction_size-unrolled sequence:
|
||||
/// ```
|
||||
/// %at = vector.transpose %a, [1, 0]
|
||||
/// %bRow0 = vector.extract %b[0]
|
||||
/// %atRow0 = vector.extract %at[0]
|
||||
/// %c0 = vector.outerproduct %atRow0, %bRow0, %c
|
||||
/// ...
|
||||
/// %bRowK = vector.extract %b[K]
|
||||
/// %atRowK = vector.extract %at[K]
|
||||
/// %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
|
||||
/// ```
|
||||
///
|
||||
/// This only kicks in when VectorTransformsOptions is set to OuterProduct and
|
||||
/// the vector.contract op is a row-major matrix multiply.
|
||||
void ContractionOpToOuterProductOpLowering::rewrite(
|
||||
vector::ContractionOp op, PatternRewriter &rewriter) const {
|
||||
VectorType lhsType = op.getLhsType();
|
||||
// TODO(ntv) other modes.
|
||||
// We know we are in row-major.
|
||||
bool transposeLhs = false;
|
||||
unsigned reductionSize =
|
||||
transposeLhs ? lhsType.getShape()[0] : lhsType.getShape()[1];
|
||||
|
||||
// If transposeLhs == false (i.e. lhs(m, reductionSize)), we need to
|
||||
// transpose it to extract the proper vector<m x f32>. Otherwise, just take
|
||||
// the lhs.
|
||||
Value lhs = transposeLhs
|
||||
? op.lhs()
|
||||
: rewriter.create<vector::TransposeOp>(
|
||||
op.getLoc(), op.lhs(), ArrayRef<int64_t>{1, 0});
|
||||
Value res = op.acc();
|
||||
// ExtractOp does not allow dynamic indexing, we must unroll explicitly.
|
||||
for (unsigned k = 0; k < reductionSize; ++k) {
|
||||
Value a = rewriter.create<vector::ExtractOp>(op.getLoc(), lhs, k);
|
||||
Value b = rewriter.create<vector::ExtractOp>(op.getLoc(), op.rhs(), k);
|
||||
res = rewriter.create<vector::OuterProductOp>(op.getLoc(), a, b, res);
|
||||
}
|
||||
rewriter.replaceOp(op, res);
|
||||
}
|
||||
|
||||
LogicalResult
|
||||
ContractionOpToOuterProductOpLowering ::match(vector::ContractionOp op) const {
|
||||
// TODO(ajcbik): implement masks
|
||||
if (llvm::size(op.masks()) != 0)
|
||||
return failure();
|
||||
|
||||
if (vectorTransformsOptions.vectorContractLowering !=
|
||||
vector::VectorContractLowering::OuterProduct ||
|
||||
!isRowMajorMatmul(op.indexing_maps()))
|
||||
return failure();
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Progressive lowering of ContractionOp.
|
||||
/// One:
|
||||
/// %x = vector.contract with at least one free/batch dimension
|
||||
/// is replaced by:
|
||||
/// %a = vector.contract with one less free/batch dimension
|
||||
/// %b = vector.contract with one less free/batch dimension
|
||||
/// ..
|
||||
/// %x = combine %a %b ..
|
||||
/// until a pure contraction is reached (no free/batch dimensions),
|
||||
/// which is replaced by a fma/reduction op.
|
||||
///
|
||||
/// TODO(ajcbik): break down into transpose/reshape/cast ops
|
||||
/// when they become available to avoid code dup
|
||||
/// TODO(ajcbik): investigate lowering order impact on performance
|
||||
LogicalResult
|
||||
ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
|
||||
PatternRewriter &rewriter) const {
|
||||
|
||||
// TODO(ajcbik): implement masks.
|
||||
if (llvm::size(op.masks()) != 0)
|
||||
return failure();
|
||||
|
||||
// TODO(ntv, ajcbik): implement benefits, cost models.
|
||||
MLIRContext *ctx = op.getContext();
|
||||
ContractionOpToMatmulOpLowering pat1(vectorTransformsOptions, ctx);
|
||||
if (succeeded(pat1.match(op)))
|
||||
return failure();
|
||||
ContractionOpToOuterProductOpLowering pat2(vectorTransformsOptions, ctx);
|
||||
if (succeeded(pat2.match(op)))
|
||||
return failure();
|
||||
|
||||
// Find first batch dimension in LHS/RHS, and lower when found.
|
||||
std::vector<std::pair<int64_t, int64_t>> batchDimMap = op.getBatchDimMap();
|
||||
if (!batchDimMap.empty()) {
|
||||
int64_t lhsIndex = batchDimMap[0].first;
|
||||
int64_t rhsIndex = batchDimMap[0].second;
|
||||
rewriter.replaceOp(op, lowerParallel(op, lhsIndex, rhsIndex, rewriter));
|
||||
return success();
|
||||
}
|
||||
|
||||
// Collect contracting dimensions.
|
||||
std::vector<std::pair<int64_t, int64_t>> contractingDimMap =
|
||||
op.getContractingDimMap();
|
||||
DenseSet<int64_t> lhsContractingDimSet;
|
||||
DenseSet<int64_t> rhsContractingDimSet;
|
||||
for (auto &dimPair : contractingDimMap) {
|
||||
lhsContractingDimSet.insert(dimPair.first);
|
||||
rhsContractingDimSet.insert(dimPair.second);
|
||||
}
|
||||
|
||||
// Find first free dimension in LHS, and lower when found.
|
||||
VectorType lhsType = op.getLhsType();
|
||||
for (int64_t lhsIndex = 0, e = lhsType.getRank(); lhsIndex < e; ++lhsIndex) {
|
||||
if (lhsContractingDimSet.count(lhsIndex) == 0) {
|
||||
rewriter.replaceOp(
|
||||
op, lowerParallel(op, lhsIndex, /*rhsIndex=*/-1, rewriter));
|
||||
return success();
|
||||
}
|
||||
}
|
||||
|
||||
// Find first free dimension in RHS, and lower when found.
|
||||
VectorType rhsType = op.getRhsType();
|
||||
for (int64_t rhsIndex = 0, e = rhsType.getRank(); rhsIndex < e; ++rhsIndex) {
|
||||
if (rhsContractingDimSet.count(rhsIndex) == 0) {
|
||||
rewriter.replaceOp(
|
||||
op, lowerParallel(op, /*lhsIndex=*/-1, rhsIndex, rewriter));
|
||||
return success();
|
||||
}
|
||||
}
|
||||
|
||||
// Lower the first remaining reduction dimension.
|
||||
if (!contractingDimMap.empty()) {
|
||||
rewriter.replaceOp(op, lowerReduction(op, rewriter));
|
||||
return success();
|
||||
}
|
||||
|
||||
return failure();
|
||||
}
|
||||
|
||||
// Lower one parallel dimension.
|
||||
// TODO(ajcbik): consider reusing existing contract unrolling
|
||||
Value ContractionOpLowering::lowerParallel(vector::ContractionOp op,
|
||||
int64_t lhsIndex, int64_t rhsIndex,
|
||||
PatternRewriter &rewriter) const {
|
||||
VectorType lhsType = op.getLhsType();
|
||||
VectorType rhsType = op.getRhsType();
|
||||
VectorType resType = op.getResultType().cast<VectorType>();
|
||||
// Find the iterator type index and result index.
|
||||
SmallVector<AffineMap, 4> iMap = op.getIndexingMaps();
|
||||
int64_t iterIndex = -1;
|
||||
int64_t dimSize = -1;
|
||||
if (lhsIndex >= 0) {
|
||||
iterIndex = iMap[0].getResult(lhsIndex).cast<AffineDimExpr>().getPosition();
|
||||
assert(
|
||||
(rhsIndex < 0 ||
|
||||
iterIndex ==
|
||||
iMap[1].getResult(rhsIndex).cast<AffineDimExpr>().getPosition()) &&
|
||||
"parallel index should be free in LHS or batch in LHS/RHS");
|
||||
dimSize = lhsType.getDimSize(lhsIndex);
|
||||
} else {
|
||||
assert(rhsIndex >= 0 && "missing parallel index");
|
||||
iterIndex = iMap[1].getResult(rhsIndex).cast<AffineDimExpr>().getPosition();
|
||||
dimSize = rhsType.getDimSize(rhsIndex);
|
||||
}
|
||||
assert(iterIndex >= 0 && "parallel index not listed in operand mapping");
|
||||
Optional<int64_t> lookup = getResultIndex(iMap[2], iterIndex);
|
||||
assert(lookup.hasValue() && "parallel index not listed in reduction");
|
||||
int64_t resIndex = lookup.getValue();
|
||||
// Construct new iterator types and affine map array attribute.
|
||||
SmallVector<AffineMap, 4> lowIndexingMaps;
|
||||
lowIndexingMaps.push_back(adjustMap(iMap[0], iterIndex, rewriter));
|
||||
lowIndexingMaps.push_back(adjustMap(iMap[1], iterIndex, rewriter));
|
||||
lowIndexingMaps.push_back(adjustMap(iMap[2], iterIndex, rewriter));
|
||||
auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
|
||||
auto lowIter =
|
||||
rewriter.getArrayAttr(adjustIter(op.iterator_types(), iterIndex));
|
||||
// Unroll into a series of lower dimensional vector.contract ops.
|
||||
Location loc = op.getLoc();
|
||||
Value result =
|
||||
rewriter.create<ConstantOp>(loc, resType, rewriter.getZeroAttr(resType));
|
||||
for (int64_t d = 0; d < dimSize; ++d) {
|
||||
auto lhs = reshapeLoad(loc, op.lhs(), lhsType, lhsIndex, d, rewriter);
|
||||
auto rhs = reshapeLoad(loc, op.rhs(), rhsType, rhsIndex, d, rewriter);
|
||||
auto acc = reshapeLoad(loc, op.acc(), resType, resIndex, d, rewriter);
|
||||
Value lowContract = rewriter.create<vector::ContractionOp>(
|
||||
loc, lhs, rhs, acc, lowAffine, lowIter);
|
||||
result =
|
||||
reshapeStore(loc, lowContract, result, resType, resIndex, d, rewriter);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
// Lower one reduction dimension.
|
||||
Value ContractionOpLowering::lowerReduction(vector::ContractionOp op,
|
||||
PatternRewriter &rewriter) const {
|
||||
auto loc = op.getLoc();
|
||||
VectorType lhsType = op.getLhsType();
|
||||
VectorType rhsType = op.getRhsType();
|
||||
Type resType = op.getResultType();
|
||||
assert(!resType.isa<VectorType>());
|
||||
// Use iterator index 0.
|
||||
int64_t iterIndex = 0;
|
||||
SmallVector<AffineMap, 4> iMap = op.getIndexingMaps();
|
||||
Optional<int64_t> lookupLhs = getResultIndex(iMap[0], iterIndex);
|
||||
Optional<int64_t> lookupRhs = getResultIndex(iMap[1], iterIndex);
|
||||
assert(lookupLhs.hasValue() && "missing LHS parallel index");
|
||||
assert(lookupRhs.hasValue() && "missing RHS parallel index");
|
||||
int64_t lhsIndex = lookupLhs.getValue();
|
||||
int64_t rhsIndex = lookupRhs.getValue();
|
||||
int64_t dimSize = lhsType.getDimSize(lhsIndex);
|
||||
assert(dimSize == rhsType.getDimSize(rhsIndex) && "corrupt shape");
|
||||
// Base case.
|
||||
if (lhsType.getRank() == 1) {
|
||||
assert(rhsType.getRank() == 1 && "corrupt contraction");
|
||||
Value zero = rewriter.create<ConstantOp>(loc, lhsType,
|
||||
rewriter.getZeroAttr(lhsType));
|
||||
Value fma = rewriter.create<vector::FMAOp>(loc, op.lhs(), op.rhs(), zero);
|
||||
StringAttr kind = rewriter.getStringAttr("add");
|
||||
return rewriter.create<vector::ReductionOp>(loc, resType, kind, fma,
|
||||
op.acc());
|
||||
}
|
||||
// Construct new iterator types and affine map array attribute.
|
||||
SmallVector<AffineMap, 4> lowIndexingMaps;
|
||||
lowIndexingMaps.push_back(adjustMap(iMap[0], iterIndex, rewriter));
|
||||
lowIndexingMaps.push_back(adjustMap(iMap[1], iterIndex, rewriter));
|
||||
lowIndexingMaps.push_back(adjustMap(iMap[2], iterIndex, rewriter));
|
||||
auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
|
||||
auto lowIter =
|
||||
rewriter.getArrayAttr(adjustIter(op.iterator_types(), iterIndex));
|
||||
// Unroll into a series of lower dimensional vector.contract ops.
|
||||
// By feeding the initial accumulator into the first contraction,
|
||||
// and the result of each contraction into the next, eventually
|
||||
// the sum of all reductions is computed.
|
||||
Value result = op.acc();
|
||||
for (int64_t d = 0; d < dimSize; ++d) {
|
||||
auto lhs = reshapeLoad(loc, op.lhs(), lhsType, lhsIndex, d, rewriter);
|
||||
auto rhs = reshapeLoad(loc, op.rhs(), rhsType, rhsIndex, d, rewriter);
|
||||
result = rewriter.create<vector::ContractionOp>(loc, lhs, rhs, result,
|
||||
lowAffine, lowIter);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
// TODO(andydavis) Add pattern to rewrite ExtractSlices(ConstantMaskOp).
|
||||
// TODO(andydavis) Add this as DRR pattern.
|
||||
void mlir::vector::populateVectorToVectorTransformationPatterns(
|
||||
@@ -1685,6 +1758,8 @@ void mlir::vector::populateVectorContractLoweringPatterns(
|
||||
ShapeCastOp2DDownCastRewritePattern,
|
||||
ShapeCastOp2DUpCastRewritePattern,
|
||||
TransposeOpLowering>(context);
|
||||
patterns.insert<ContractionOpLowering,
|
||||
ContractionOpToMatmulOpLowering,
|
||||
ContractionOpToOuterProductOpLowering>(parameters, context);
|
||||
// clang-format on
|
||||
patterns.insert<ContractionOpLowering>(parameters, context);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user