mirror of
https://github.com/intel/llvm.git
synced 2026-01-24 08:30:34 +08:00
[nlir][vector] Constrain ContractionOpToMatmulOpLowering (#102225)
Disables `ContractionOpToMatmulOpLowering` for scalable vectors. This pattern is meant to enable lowering to `llvm.matrix.multiply` - I'm not aware of any use of that in the context of scalable vectors.
This commit is contained in:
committed by
GitHub
parent
9dae7fcc92
commit
cb89457ff8
@@ -1283,6 +1283,8 @@ public:
|
||||
/// This only kicks in when VectorTransformsOptions is set to `Matmul`.
|
||||
/// vector.transpose operations are inserted if the vector.contract op is not a
|
||||
/// row-major matrix multiply.
|
||||
///
|
||||
/// Scalable vectors are not supported.
|
||||
FailureOr<Value> ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp(
|
||||
vector::ContractionOp op, MaskingOpInterface maskOp,
|
||||
PatternRewriter &rew) const {
|
||||
@@ -1302,13 +1304,18 @@ FailureOr<Value> ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp(
|
||||
!isReductionIterator(iteratorTypes[2]))
|
||||
return failure();
|
||||
|
||||
Type opResType = op.getType();
|
||||
VectorType vecType = dyn_cast<VectorType>(opResType);
|
||||
if (vecType && vecType.isScalable()) {
|
||||
// Note - this is sufficient to reject all cases with scalable vectors.
|
||||
return failure();
|
||||
}
|
||||
|
||||
Type elementType = op.getLhsType().getElementType();
|
||||
if (!elementType.isIntOrFloat())
|
||||
return failure();
|
||||
|
||||
Type dstElementType = op.getType();
|
||||
if (auto vecType = dyn_cast<VectorType>(dstElementType))
|
||||
dstElementType = vecType.getElementType();
|
||||
Type dstElementType = vecType ? vecType.getElementType() : opResType;
|
||||
if (elementType != dstElementType)
|
||||
return failure();
|
||||
|
||||
|
||||
@@ -36,13 +36,23 @@
|
||||
// CHECK: %[[mm5:.*]] = vector.insert %[[mm4]], %[[mm3]] [1] : vector<3xf32> into vector<2x3xf32>
|
||||
// CHECK: %[[mm6:.*]] = arith.addf %[[C]], %[[mm5]] : vector<2x3xf32>
|
||||
func.func @matmul(%arg0: vector<2x4xf32>,
|
||||
%arg1: vector<4x3xf32>,
|
||||
%arg2: vector<2x3xf32>) -> vector<2x3xf32> {
|
||||
%arg1: vector<4x3xf32>,
|
||||
%arg2: vector<2x3xf32>) -> vector<2x3xf32> {
|
||||
%0 = vector.contract #matmat_trait %arg0, %arg1, %arg2
|
||||
: vector<2x4xf32>, vector<4x3xf32> into vector<2x3xf32>
|
||||
return %0 : vector<2x3xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @matmul_scalable
|
||||
// CHECK-NOT: vector.matrix_multiply
|
||||
func.func @matmul_scalable(%arg0: vector<2x4xf32>,
|
||||
%arg1: vector<4x[3]xf32>,
|
||||
%arg2: vector<2x[3]xf32>) -> vector<2x[3]xf32> {
|
||||
%0 = vector.contract #matmat_trait %arg0, %arg1, %arg2
|
||||
: vector<2x4xf32>, vector<4x[3]xf32> into vector<2x[3]xf32>
|
||||
return %0 : vector<2x[3]xf32>
|
||||
}
|
||||
|
||||
module attributes {transform.with_named_sequence} {
|
||||
transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
|
||||
%f = transform.structured.match ops{["func.func"]} in %module_op
|
||||
|
||||
Reference in New Issue
Block a user