[mlir] [vector] Add an optional filter to vector contract lowering patterns.

Summary: Vector contract patterns were only parameterized by a `vectorTransformsOptions`. As a result, even if an mlir file was containing several occurrences of `vector.contract`, all of them would be lowered in the same way. More granularity might be required . This Diff adds a `constraint` argument to each of these patterns which allows the user to specify with more precision on which `vector.contract` should each of the lowering apply.

Differential Revision: https://reviews.llvm.org/D83960
This commit is contained in:
Pierre Oechsel
2020-07-17 12:02:11 -04:00
committed by Nicolas Vasilache
parent 1afd889d0b
commit ec62e37c86
4 changed files with 91 additions and 6 deletions

View File

@@ -1581,6 +1581,9 @@ ContractionOpToMatmulOpLowering::match(vector::ContractionOp op) const {
vector::VectorContractLowering::Matmul)
return failure();
if (failed(filter(op)))
return failure();
auto iteratorTypes = op.iterator_types().getValue();
if (!isParallelIterator(iteratorTypes[0]) ||
!isParallelIterator(iteratorTypes[1]) ||
@@ -1647,6 +1650,9 @@ ContractionOpToOuterProductOpLowering::match(vector::ContractionOp op) const {
vector::VectorContractLowering::OuterProduct)
return failure();
if (failed(filter(op)))
return failure();
// Determine if the parallel/reduction structure matches something
// that can be expressed a reduction_size unrolled sequence.
using MapList = ArrayRef<ArrayRef<AffineExpr>>;
@@ -1808,6 +1814,10 @@ ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
// TODO: implement masks.
if (llvm::size(op.masks()) != 0)
return failure();
if (failed(filter(op)))
return failure();
// TODO: support mixed mode contract lowering.
if (op.getLhsType().getElementType() !=
getElementTypeOrSelf(op.getAccType()) ||