[mlir][Vector] Add more vector.contract -> outerproduct lowerings and fix vector.contract type inference.

This revision expands the types of vector contractions that can be lowered to vector.outerproduct.
All 8 permutation cases are now supported.
The idiomatic, declarative manipulation of AffineMaps makes this straightforward.
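
For illustration (the value names and shapes below are mine, not taken from the patch), here is one of the newly supported permutations: an lhs already indexed `(k, m)`, which previously required the row-major `(m, k)` layout. It can now be lowered directly, one outer product per reduction step:

```mlir
#lhs = affine_map<(m, n, k) -> (k, m)>
#rhs = affine_map<(m, n, k) -> (k, n)>
#res = affine_map<(m, n, k) -> (m, n)>

// k = 4 reduction steps; each row of %a is already a vector<2xf32> slice.
%c = vector.contract {indexing_maps = [#lhs, #rhs, #res],
                      iterator_types = ["parallel", "parallel", "reduction"]}
       %a, %b, %acc : vector<4x2xf32>, vector<4x3xf32> into vector<2x3xf32>

// Roughly lowers to an unrolled chain of outer products:
%a0 = vector.extract %a[0] : vector<4x2xf32>
%b0 = vector.extract %b[0] : vector<4x3xf32>
%c0 = vector.outerproduct %a0, %b0, %acc : vector<2xf32>, vector<3xf32>
// ... three more steps for rows 1, 2, 3, threading the accumulator through.
```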

In the process, a bug in the vector.contract verifier was uncovered.
The vector shape verification part of the contract op is rewritten to use AffineMap composition.
One bug in the vector `ops.mlir` test is fixed, and a new case that was not previously captured is added
to the vector `invalid.mlir` test.
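
The idea behind the map-composition check, sketched on a hypothetical example (this is not the exact case added to `invalid.mlir`): each operand's shape must agree with its indexing map applied to the iteration-space sizes, so a mismatched accumulator/result is now rejected:

```mlir
#lhs = affine_map<(m, n, k) -> (m, k)>
#rhs = affine_map<(m, n, k) -> (k, n)>
#res = affine_map<(m, n, k) -> (m, n)>

// lhs = vector<2x4xf32> fixes m = 2, k = 4; rhs = vector<4x3xf32> fixes n = 3.
// Applying #res to (m, n, k) = (2, 3, 4) forces the result to be
// vector<2x3xf32>, so this vector<3x2xf32> result should now be rejected:
%bad = vector.contract {indexing_maps = [#lhs, #rhs, #res],
                        iterator_types = ["parallel", "parallel", "reduction"]}
         %a, %b, %acc : vector<2x4xf32>, vector<4x3xf32> into vector<3x2xf32>
```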

Differential Revision: https://reviews.llvm.org/D80393
Author: Nicolas Vasilache
Date: 2020-05-26 15:34:57 -04:00
Commit: ba10daa820 (parent: e99d50d844)
6 changed files with 407 additions and 53 deletions


@@ -1454,10 +1454,17 @@ ContractionOpToMatmulOpLowering::match(vector::ContractionOp op) const {
  if (llvm::size(op.masks()) != 0)
    return failure();
  auto iteratorTypes = op.iterator_types().getValue();
  if (!isParallelIterator(iteratorTypes[0]) ||
      !isParallelIterator(iteratorTypes[1]) ||
      !isReductionIterator(iteratorTypes[2]))
    return failure();
  if (vectorTransformsOptions.vectorContractLowering !=
          vector::VectorContractLowering::Matmul ||
      !isRowMajorMatmul(op.indexing_maps()))
    return failure();
  return success();
}
@@ -1503,34 +1510,8 @@ void ContractionOpToMatmulOpLowering::rewrite(vector::ContractionOp op,
///   %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
/// ```
///
/// This only kicks in when VectorTransformsOptions is set to OuterProduct and
/// the vector.contract op is a row-major matrix multiply.
void ContractionOpToOuterProductOpLowering::rewrite(
    vector::ContractionOp op, PatternRewriter &rewriter) const {
  VectorType lhsType = op.getLhsType();
  // TODO(ntv) other modes.
  // We know we are in row-major.
  bool transposeLhs = false;
  unsigned reductionSize =
      transposeLhs ? lhsType.getShape()[0] : lhsType.getShape()[1];
  // If transposeLhs == false (i.e. lhs(m, reductionSize)), we need to
  // transpose it to extract the proper vector<m x f32>. Otherwise, just take
  // the lhs.
  Value lhs = transposeLhs
                  ? op.lhs()
                  : rewriter.create<vector::TransposeOp>(
                        op.getLoc(), op.lhs(), ArrayRef<int64_t>{1, 0});
  Value res = op.acc();
  // ExtractOp does not allow dynamic indexing, we must unroll explicitly.
  for (unsigned k = 0; k < reductionSize; ++k) {
    Value a = rewriter.create<vector::ExtractOp>(op.getLoc(), lhs, k);
    Value b = rewriter.create<vector::ExtractOp>(op.getLoc(), op.rhs(), k);
    res = rewriter.create<vector::OuterProductOp>(op.getLoc(), a, b, res);
  }
  rewriter.replaceOp(op, res);
}

/// This only kicks in when VectorTransformsOptions is set to OuterProduct but
/// otherwise supports any layout permutation of the matrix-multiply.
LogicalResult
ContractionOpToOuterProductOpLowering::match(vector::ContractionOp op) const {
  // TODO(ajcbik): implement masks
@@ -1538,12 +1519,104 @@ ContractionOpToOuterProductOpLowering::match(vector::ContractionOp op) const {
    return failure();
  if (vectorTransformsOptions.vectorContractLowering !=
          vector::VectorContractLowering::OuterProduct ||
      !isRowMajorMatmul(op.indexing_maps()))
          vector::VectorContractLowering::OuterProduct)
    return failure();
  // Transpose arguments to make them ready for lowering to OuterProduct. The
  // constraint to match is that we must load full rows at a time with
  // vector::ExtractOp.
  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
  auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
  AffineExpr m, n, k;
  bindDims(op.getContext(), m, n, k);
  auto iteratorTypes = op.iterator_types().getValue();
  if (!isParallelIterator(iteratorTypes[0]) ||
      !isParallelIterator(iteratorTypes[1]) ||
      !isReductionIterator(iteratorTypes[2]))
    return failure();
  SmallVector<AffineMap, 4> maps = op.getIndexingMaps();
  // When lowering to outerproduct we can support all permutations.
  if (maps != infer({{m, k}, {k, n}, {m, n}}) &&
      maps != infer({{m, k}, {n, k}, {m, n}}) &&
      maps != infer({{k, m}, {k, n}, {m, n}}) &&
      maps != infer({{k, m}, {n, k}, {m, n}}) &&
      maps != infer({{m, k}, {k, n}, {n, m}}) &&
      maps != infer({{m, k}, {n, k}, {n, m}}) &&
      maps != infer({{k, m}, {k, n}, {n, m}}) &&
      maps != infer({{k, m}, {n, k}, {n, m}}))
    return failure();
  return success();
}

void ContractionOpToOuterProductOpLowering::rewrite(
    vector::ContractionOp op, PatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  unsigned reductionSize = 0;
  VectorType lhsType = op.getLhsType();
  Value lhs = op.lhs(), rhs = op.rhs(), res = op.acc();
  // Transpose arguments to make them ready for lowering to OuterProduct. The
  // constraint to match is that we must load full rows at a time with
  // vector::ExtractOp.
  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
  auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
  AffineExpr m, n, k;
  bindDims(rewriter.getContext(), m, n, k);
  SmallVector<int64_t, 2> perm{1, 0};
  SmallVector<AffineMap, 4> maps = op.getIndexingMaps();
  // First batch of cases, no need to output permute.
  if (maps == infer({{m, k}, {k, n}, {m, n}})) {
    // This is the classical row-major matmul. Just permute the lhs.
    reductionSize = lhsType.getShape()[1];
    lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
  } else if (maps == infer({{m, k}, {n, k}, {m, n}})) {
    // TODO: may be better to fail and use some vector<k> -> scalar reduction.
    reductionSize = lhsType.getShape()[1];
    lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
  } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
    // No need to permute anything.
    reductionSize = lhsType.getShape()[0];
  } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
    // Just permute the rhs.
    reductionSize = lhsType.getShape()[0];
    rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
  }
  // Second batch of cases, reshuffle to avoid output permute.
  else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
    // This is the classical row-major matmul. Just permute the lhs.
    reductionSize = lhsType.getShape()[1];
    Value tmp = rhs;
    rhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    lhs = tmp;
  } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
    // TODO: may be better to fail and use some vector<k> -> scalar reduction.
    reductionSize = lhsType.getShape()[1];
    Value tmp = rhs;
    rhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    lhs = rewriter.create<vector::TransposeOp>(loc, tmp, perm);
  } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
    // No need to permute anything, but still swap lhs and rhs.
    reductionSize = lhsType.getShape()[0];
    std::swap(lhs, rhs);
  } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
    // Just permute the rhs.
    reductionSize = lhsType.getShape()[0];
    Value tmp = lhs;
    lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    rhs = tmp;
  }
  assert(reductionSize > 0);
  // ExtractOp does not allow dynamic indexing, we must unroll explicitly.
  for (unsigned k = 0; k < reductionSize; ++k) {
    Value a = rewriter.create<vector::ExtractOp>(op.getLoc(), lhs, k);
    Value b = rewriter.create<vector::ExtractOp>(op.getLoc(), rhs, k);
    res = rewriter.create<vector::OuterProductOp>(op.getLoc(), a, b, res);
  }
  rewriter.replaceOp(op, res);
}
/// Progressive lowering of ContractionOp.
/// One:
///   %x = vector.contract with at least one free/batch dimension