diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index eac6db585aad..da3d9648cf28 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1552,6 +1552,7 @@ private:
 /// Cores, i.e, `mma.sync.*.f32.f16.f16.f32` and `mma.sync.*.f32.bf16.bf16.f32`.
 /// This pattern folds the arithmetic extensions into the vector contraction and
 /// enables the usage of native mixed precision Tensor Core instructions.
+template <typename ExtOp>
 struct FoldArithExtIntoContractionOp
     : public OpRewritePattern<vector::ContractionOp> {
   using OpRewritePattern::OpRewritePattern;
@@ -1559,8 +1560,8 @@ struct FoldArithExtIntoContractionOp
 
   LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                 PatternRewriter &rewriter) const override {
-    auto lhsDefOp = contractOp.getLhs().getDefiningOp<arith::ExtFOp>();
-    auto rhsDefOp = contractOp.getRhs().getDefiningOp<arith::ExtFOp>();
+    auto lhsDefOp = contractOp.getLhs().getDefiningOp<ExtOp>();
+    auto rhsDefOp = contractOp.getRhs().getDefiningOp<ExtOp>();
 
     if (!lhsDefOp || !rhsDefOp) {
       return rewriter.notifyMatchFailure(contractOp,
@@ -1895,7 +1896,9 @@ struct FoldArithToVectorOuterProduct : public OpRewritePattern<MulOpType> {
 
 void mlir::vector::populateFoldArithExtensionPatterns(
     RewritePatternSet &patterns) {
-  patterns.add<FoldArithExtIntoContractionOp>(patterns.getContext());
+  patterns.add<FoldArithExtIntoContractionOp<arith::ExtFOp>,
+               FoldArithExtIntoContractionOp<arith::ExtSIOp>>(
+      patterns.getContext());
 }
 
 void mlir::vector::populateVectorMaskMaterializationPatterns(
diff --git a/mlir/test/Dialect/Vector/fold-arith-extf-into-vector-contract.mlir b/mlir/test/Dialect/Vector/fold-arith-extf-into-vector-contract.mlir
index 31ae126906f2..6dbde7afbdd3 100644
--- a/mlir/test/Dialect/Vector/fold-arith-extf-into-vector-contract.mlir
+++ b/mlir/test/Dialect/Vector/fold-arith-extf-into-vector-contract.mlir
@@ -48,3 +48,25 @@ func.func @fold_arith_extf_into_contract_scalable(
     %lhs_f32, %rhs_f32, %arg2 : vector<[64]x64xf32>, vector<64x64xf32> into vector<[64]x64xf32>
   return %result : vector<[64]x64xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @fold_arith_extsi_into_contract
+// CHECK-SAME: (%[[ARG0:.*]]: vector<64x64xi8>, %[[ARG1:.*]]: vector<64x64xi8>, %[[ARG2:.*]]: vector<64x64xi32>)
+// CHECK-NEXT: %[[R:.+]] = vector.contract
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]] : vector<64x64xi8>, vector<64x64xi8> into vector<64x64xi32>
+// CHECK-NEXT: return %[[R]] : vector<64x64xi32>
+func.func @fold_arith_extsi_into_contract(
+    %arg0: vector<64x64xi8>,
+    %arg1: vector<64x64xi8>,
+    %arg2: vector<64x64xi32>) -> vector<64x64xi32> {
+  %lhs_i32 = arith.extsi %arg0 : vector<64x64xi8> to vector<64x64xi32>
+  %rhs_i32 = arith.extsi %arg1 : vector<64x64xi8> to vector<64x64xi32>
+  %result = vector.contract {
+    indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
+    iterator_types = ["parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %lhs_i32, %rhs_i32, %arg2 : vector<64x64xi32>, vector<64x64xi32> into vector<64x64xi32>
+  return %result : vector<64x64xi32>
+}