[mlir][vector] Generalize folding of ext-contractionOp to other types. (#96593)
Many state-of-the-art models and quantization operations now work directly on integer vector.contract ops. This commit generalizes the ext-contraction folding so that codegen pipelines can emit more performant vector.contract ops. Signed-off-by: Stanley Winata <stanley.winata@amd.com>
@@ -1552,6 +1552,7 @@ private:
 /// Cores, i.e, `mma.sync.*.f32.f16.f16.f32` and `mma.sync.*.f32.bf16.bf16.f32`.
 /// This pattern folds the arithmetic extensions into the vector contraction and
 /// enables the usage of native mixed precision Tensor Core instructions.
+template <typename ExtOp>
 struct FoldArithExtIntoContractionOp
     : public OpRewritePattern<vector::ContractionOp> {
   using OpRewritePattern::OpRewritePattern;
@@ -1559,8 +1560,8 @@ struct FoldArithExtIntoContractionOp
   LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                 PatternRewriter &rewriter) const override {
 
-    auto lhsDefOp = contractOp.getLhs().getDefiningOp<arith::ExtFOp>();
-    auto rhsDefOp = contractOp.getRhs().getDefiningOp<arith::ExtFOp>();
+    auto lhsDefOp = contractOp.getLhs().getDefiningOp<ExtOp>();
+    auto rhsDefOp = contractOp.getRhs().getDefiningOp<ExtOp>();
 
     if (!lhsDefOp || !rhsDefOp) {
       return rewriter.notifyMatchFailure(contractOp,
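The hunk above cuts off before the rewrite itself, which this commit leaves unchanged. For context, a minimal sketch of what the successful path plausibly does once both defining ops match; the exact builder call is an assumption, not part of this diff:

// Sketch (assumed, not shown in this diff): rebuild the contraction on the
// pre-extension operands, so the extension is absorbed into the contract
// and later lowerings see a mixed-precision vector.contract.
rewriter.replaceOpWithNewOp<vector::ContractionOp>(
    contractOp, lhsDefOp->getOperand(0), rhsDefOp->getOperand(0),
    contractOp.getAcc(), contractOp.getIndexingMapsAttr(),
    contractOp.getIteratorTypesAttr());
return success();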
@@ -1895,7 +1896,9 @@ struct FoldArithToVectorOuterProduct : public OpRewritePattern<MulOpType> {
 
 void mlir::vector::populateFoldArithExtensionPatterns(
     RewritePatternSet &patterns) {
-  patterns.add<FoldArithExtIntoContractionOp>(patterns.getContext());
+  patterns.add<FoldArithExtIntoContractionOp<arith::ExtFOp>,
+               FoldArithExtIntoContractionOp<arith::ExtSIOp>>(
+      patterns.getContext());
 }
 
 void mlir::vector::populateVectorMaskMaterializationPatterns(
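Downstream pipelines pick these patterns up through the populate function above, typically via a greedy rewrite. A minimal sketch of that usage; the helper name is hypothetical, while populateFoldArithExtensionPatterns and applyPatternsAndFoldGreedily are the real MLIR entry points:

#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Hypothetical helper: greedily fold arith.extf/arith.extsi producers into
// vector.contract ops nested under `root`.
static mlir::LogicalResult foldExtsIntoContracts(mlir::Operation *root) {
  mlir::RewritePatternSet patterns(root->getContext());
  mlir::vector::populateFoldArithExtensionPatterns(patterns);
  return mlir::applyPatternsAndFoldGreedily(root, std::move(patterns));
}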
@@ -48,3 +48,25 @@ func.func @fold_arith_extf_into_contract_scalable(
       %lhs_f32, %rhs_f32, %arg2 : vector<[64]x64xf32>, vector<64x64xf32> into vector<[64]x64xf32>
   return %result : vector<[64]x64xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @fold_arith_extsi_into_contract
+//  CHECK-SAME: (%[[ARG0:.*]]: vector<64x64xi8>, %[[ARG1:.*]]: vector<64x64xi8>, %[[ARG2:.*]]: vector<64x64xi32>)
+//  CHECK-NEXT:   %[[R:.+]] = vector.contract
+//  CHECK-SAME:   iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+//  CHECK-SAME:   %[[ARG0]], %[[ARG1]], %[[ARG2]] : vector<64x64xi8>, vector<64x64xi8> into vector<64x64xi32>
+//  CHECK-NEXT:   return %[[R]] : vector<64x64xi32>
+func.func @fold_arith_extsi_into_contract(
+    %arg0: vector<64x64xi8>,
+    %arg1: vector<64x64xi8>,
+    %arg2: vector<64x64xi32>) -> vector<64x64xi32> {
+  %lhs_i32 = arith.extsi %arg0 : vector<64x64xi8> to vector<64x64xi32>
+  %rhs_i32 = arith.extsi %arg1 : vector<64x64xi8> to vector<64x64xi32>
+  %result = vector.contract {
+      indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
+      iterator_types = ["parallel", "parallel", "reduction"],
+      kind = #vector.kind<add>}
+      %lhs_i32, %rhs_i32, %arg2 : vector<64x64xi32>, vector<64x64xi32> into vector<64x64xi32>
+  return %result : vector<64x64xi32>
+}