[mlir][Vector] Add more vector.contract -> outerproduct lowerings and fix vector.contract type inference.
This revision expands the types of vector contractions that can be lowered to vector.outerproduct. All 8 permutation cases are supported. The idiomatic manipulation of AffineMap, written declaratively, makes this straightforward. In the process, a bug in the vector.contract verifier was uncovered. The vector shape verification part of the contract op is rewritten to use AffineMap composition. One bug in the vector `ops.mlir` test is fixed, and a new case not yet captured is added to the vector `invalid.mlir` test.

Differential Revision: https://reviews.llvm.org/D80393
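For readers skimming the diff below, the core of the new matching logic is a comparison of the op's indexing maps against maps inferred from declaratively written expression lists. The following is a minimal standalone sketch of that idiom, not code from the patch; the helper name isOuterProductReadyLayout and the include paths are illustrative assumptions for the patch's vintage of MLIR and may differ in newer trees.

// Sketch only (illustrative, not part of this patch): recognize the
// permutation where the lhs is indexed as (k, m), the rhs as (k, n) and the
// result as (m, n) -- the case the patch's comments call "No need to permute
// anything" -- using the same AffineMap-inference idiom the lowering relies on.
#include "mlir/Dialect/Vector/VectorOps.h" // assumed header location
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"

using namespace mlir;

static bool isOuterProductReadyLayout(vector::ContractionOp op) {
  // Bind the three iteration dimensions: m, n parallel, k reduction.
  AffineExpr m, n, k;
  bindDims(op.getContext(), m, n, k);
  // Build reference indexing maps from declarative expression lists.
  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
  auto infer = [](MapList exprs) { return AffineMap::inferFromExprList(exprs); };
  // getIndexingMaps() returns {lhsMap, rhsMap, accMap}; element-wise equality
  // against the inferred maps identifies the permutation case.
  SmallVector<AffineMap, 4> maps = op.getIndexingMaps();
  return maps == infer({{k, m}, {k, n}, {m, n}});
}

In the rewrite added below, the same comparison decides which operand gets a vector.transpose before the loop that unrolls the contraction into vector.outerproduct ops.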
@@ -1454,10 +1454,17 @@ ContractionOpToMatmulOpLowering::match(vector::ContractionOp op) const {
  if (llvm::size(op.masks()) != 0)
    return failure();

  auto iteratorTypes = op.iterator_types().getValue();
  if (!isParallelIterator(iteratorTypes[0]) ||
      !isParallelIterator(iteratorTypes[1]) ||
      !isReductionIterator(iteratorTypes[2]))
    return failure();

  if (vectorTransformsOptions.vectorContractLowering !=
          vector::VectorContractLowering::Matmul ||
      !isRowMajorMatmul(op.indexing_maps()))
    return failure();

  return success();
}

@@ -1503,34 +1510,8 @@ void ContractionOpToMatmulOpLowering::rewrite(vector::ContractionOp op,
/// %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
/// ```
///
/// This only kicks in when VectorTransformsOptions is set to OuterProduct and
/// the vector.contract op is a row-major matrix multiply.
void ContractionOpToOuterProductOpLowering::rewrite(
    vector::ContractionOp op, PatternRewriter &rewriter) const {
  VectorType lhsType = op.getLhsType();
  // TODO(ntv) other modes.
  // We know we are in row-major.
  bool transposeLhs = false;
  unsigned reductionSize =
      transposeLhs ? lhsType.getShape()[0] : lhsType.getShape()[1];

  // If transposeLhs == false (i.e. lhs(m, reductionSize)), we need to
  // transpose it to extract the proper vector<m x f32>. Otherwise, just take
  // the lhs.
  Value lhs = transposeLhs
                  ? op.lhs()
                  : rewriter.create<vector::TransposeOp>(
                        op.getLoc(), op.lhs(), ArrayRef<int64_t>{1, 0});
  Value res = op.acc();
  // ExtractOp does not allow dynamic indexing, we must unroll explicitly.
  for (unsigned k = 0; k < reductionSize; ++k) {
    Value a = rewriter.create<vector::ExtractOp>(op.getLoc(), lhs, k);
    Value b = rewriter.create<vector::ExtractOp>(op.getLoc(), op.rhs(), k);
    res = rewriter.create<vector::OuterProductOp>(op.getLoc(), a, b, res);
  }
  rewriter.replaceOp(op, res);
}

/// This only kicks in when VectorTransformsOptions is set to OuterProduct but
/// otherwise supports any layout permutation of the matrix-multiply.
LogicalResult
ContractionOpToOuterProductOpLowering ::match(vector::ContractionOp op) const {
  // TODO(ajcbik): implement masks
@@ -1538,12 +1519,104 @@ ContractionOpToOuterProductOpLowering ::match(vector::ContractionOp op) const {
    return failure();

  if (vectorTransformsOptions.vectorContractLowering !=
          vector::VectorContractLowering::OuterProduct ||
      !isRowMajorMatmul(op.indexing_maps()))
          vector::VectorContractLowering::OuterProduct)
    return failure();

  // Transpose arguments to make them ready for lowering to OuterProduct. The
  // constraint to match is that we must load full rows at a time with
  // vector::ExtractOp.
  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
  auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
  AffineExpr m, n, k;
  bindDims(op.getContext(), m, n, k);
  auto iteratorTypes = op.iterator_types().getValue();
  if (!isParallelIterator(iteratorTypes[0]) ||
      !isParallelIterator(iteratorTypes[1]) ||
      !isReductionIterator(iteratorTypes[2]))
    return failure();
  SmallVector<AffineMap, 4> maps = op.getIndexingMaps();
  // When lowering to outerproduct we can support all permutations.
  if (maps != infer({{m, k}, {k, n}, {m, n}}) &&
      maps != infer({{m, k}, {n, k}, {m, n}}) &&
      maps != infer({{k, m}, {k, n}, {m, n}}) &&
      maps != infer({{k, m}, {n, k}, {m, n}}) &&
      maps != infer({{m, k}, {k, n}, {n, m}}) &&
      maps != infer({{m, k}, {n, k}, {n, m}}) &&
      maps != infer({{k, m}, {k, n}, {n, m}}) &&
      maps != infer({{k, m}, {n, k}, {n, m}}))
    return failure();
  return success();
}

void ContractionOpToOuterProductOpLowering::rewrite(
    vector::ContractionOp op, PatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  unsigned reductionSize = 0;
  VectorType lhsType = op.getLhsType();
  Value lhs = op.lhs(), rhs = op.rhs(), res = op.acc();

  // Transpose arguments to make them ready for lowering to OuterProduct. The
  // constraint to match is that we must load full rows at a time with
  // vector::ExtractOp.
  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
  auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
  AffineExpr m, n, k;
  bindDims(rewriter.getContext(), m, n, k);
  SmallVector<int64_t, 2> perm{1, 0};
  SmallVector<AffineMap, 4> maps = op.getIndexingMaps();
  // First batch of cases, no need to output permute.
  if (maps == infer({{m, k}, {k, n}, {m, n}})) {
    // This is the classical row-major matmul. Just permute the lhs.
    reductionSize = lhsType.getShape()[1];
    lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
  } else if (maps == infer({{m, k}, {n, k}, {m, n}})) {
    // TODO: may be better to fail and use some vector<k> -> scalar reduction.
    reductionSize = lhsType.getShape()[1];
    lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
  } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
    // No need to permute anything.
    reductionSize = lhsType.getShape()[0];
  } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
    // Just permute the rhs.
    reductionSize = lhsType.getShape()[0];
    rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
  }
  // Second batch of cases, reshuffle to avoid output permute.
  else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
    // This is the classical row-major matmul. Just permute the lhs.
    reductionSize = lhsType.getShape()[1];
    Value tmp = rhs;
    rhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    lhs = tmp;
  } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
    // TODO: may be better to fail and use some vector<k> -> scalar reduction.
    reductionSize = lhsType.getShape()[1];
    Value tmp = rhs;
    rhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    lhs = rewriter.create<vector::TransposeOp>(loc, tmp, perm);
  } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
    // No need to permute anything, but still swap lhs and rhs.
    reductionSize = lhsType.getShape()[0];
    std::swap(lhs, rhs);
  } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
    // Just permute the rhs.
    reductionSize = lhsType.getShape()[0];
    Value tmp = lhs;
    lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    rhs = tmp;
  }
  assert(reductionSize > 0);

  // ExtractOp does not allow dynamic indexing, we must unroll explicitly.
  for (unsigned k = 0; k < reductionSize; ++k) {
    Value a = rewriter.create<vector::ExtractOp>(op.getLoc(), lhs, k);
    Value b = rewriter.create<vector::ExtractOp>(op.getLoc(), rhs, k);
    res = rewriter.create<vector::OuterProductOp>(op.getLoc(), a, b, res);
  }
  rewriter.replaceOp(op, res);
}

/// Progressive lowering of ContractionOp.
/// One:
/// %x = vector.contract with at least one free/batch dimension