mirror of
https://github.com/intel/llvm.git
synced 2026-01-25 19:44:38 +08:00
[MLIR][Vector] Implement TransferOpReduceRank as MaskableOpRewritePattern (#92426)
Implements `TransferOpReduceRank` as a `MaskableOpRewritePattern`. Allowing to exit gracefully when run on a `vector::transfer_read` located inside a `vector::MaskOp` instead of generating `error: 'vector.mask' op expects only one operation to mask` because the pattern generated multiple ops inside the MaskOp. Split of https://github.com/llvm/llvm-project/pull/90835
This commit is contained in:
@@ -322,14 +322,20 @@ struct TransferWriteNonPermutationLowering
|
||||
/// %v = vector.transfer_read ...
|
||||
/// permutation_map: (d0, d1, d2, d3) -> (d1, 0, d3)
|
||||
/// vector.broadcast %v
|
||||
struct TransferOpReduceRank : public OpRewritePattern<vector::TransferReadOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
struct TransferOpReduceRank
|
||||
: public MaskableOpRewritePattern<vector::TransferReadOp> {
|
||||
using MaskableOpRewritePattern::MaskableOpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(vector::TransferReadOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
FailureOr<mlir::Value>
|
||||
matchAndRewriteMaskableOp(vector::TransferReadOp op,
|
||||
MaskingOpInterface maskOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// TODO: support 0-d corner case.
|
||||
if (op.getTransferRank() == 0)
|
||||
return rewriter.notifyMatchFailure(op, "0-d corner case not supported");
|
||||
// TODO: support masked case.
|
||||
if (maskOp)
|
||||
return rewriter.notifyMatchFailure(op, "Masked case not supported");
|
||||
|
||||
AffineMap map = op.getPermutationMap();
|
||||
unsigned numLeadingBroadcast = 0;
|
||||
@@ -369,9 +375,9 @@ struct TransferOpReduceRank : public OpRewritePattern<vector::TransferReadOp> {
|
||||
op.getLoc(), originalVecType.getElementType(), op.getSource(),
|
||||
op.getIndices());
|
||||
}
|
||||
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, originalVecType,
|
||||
newRead);
|
||||
return success();
|
||||
return rewriter
|
||||
.create<vector::BroadcastOp>(op.getLoc(), originalVecType, newRead)
|
||||
.getVector();
|
||||
}
|
||||
|
||||
SmallVector<int64_t> newShape(
|
||||
@@ -393,9 +399,9 @@ struct TransferOpReduceRank : public OpRewritePattern<vector::TransferReadOp> {
|
||||
op.getLoc(), newReadType, op.getSource(), op.getIndices(),
|
||||
AffineMapAttr::get(newMap), op.getPadding(), op.getMask(),
|
||||
newInBoundsAttr);
|
||||
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, originalVecType,
|
||||
newRead);
|
||||
return success();
|
||||
return rewriter
|
||||
.create<vector::BroadcastOp>(op.getLoc(), originalVecType, newRead)
|
||||
.getVector();
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -187,3 +187,49 @@ module attributes {transform.with_named_sequence} {
|
||||
transform.yield
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
|
||||
// CHECK: #[[MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, 0, d3)>
|
||||
// CHECK: func.func @transfer_read_reduce_rank_scalable(
|
||||
// CHECK-SAME: %[[ARG_0:.*]]: memref<?x?x?x?xf32>) -> vector<8x[4]x2x3xf32> {
|
||||
// CHECK: %[[C0:.*]] = arith.constant 0 : index
|
||||
// CHECK: %[[TFR:.*]] = vector.transfer_read %arg0[%[[C0]], %[[C0]], %[[C0]], %[[C0]]]{{.*}} permutation_map = #[[MAP]]} : memref<?x?x?x?xf32>, vector<[4]x2x3xf32>
|
||||
// CHECK: %[[BC:.*]] = vector.broadcast %[[TFR]] : vector<[4]x2x3xf32> to vector<8x[4]x2x3xf32>
|
||||
// CHECK: return %[[BC]] : vector<8x[4]x2x3xf32>
|
||||
func.func @transfer_read_reduce_rank_scalable(%mem: memref<?x?x?x?xf32>) -> vector<8x[4]x2x3xf32> {
|
||||
%c0 = arith.constant 0 : index
|
||||
%cst_0 = arith.constant 0.000000e+00 : f32
|
||||
%1 = vector.transfer_read %mem[%c0, %c0, %c0, %c0], %cst_0
|
||||
{in_bounds = [true, true, true, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (0, d1, 0, d3)>}
|
||||
: memref<?x?x?x?xf32>, vector<8x[4]x2x3xf32>
|
||||
return %1 : vector<8x[4]x2x3xf32>
|
||||
}
|
||||
|
||||
// Masked case not supported.
|
||||
// CHECK-LABEL: func.func @masked_transfer_read_reduce_rank(
|
||||
// CHECK-SAME: %[[ARG_0:.*]]: memref<?x?x?x?xf32>,
|
||||
// CHECK-SAME: %[[DIM:.*]]: index) -> vector<8x[4]x2x3xf32> {
|
||||
// CHECK-NOT: vector.broadcast
|
||||
// CHECK: %[[MASK:.*]] = vector.mask %0 { vector.transfer_read %arg0{{.*}} : memref<?x?x?x?xf32>, vector<8x[4]x2x3xf32> } : vector<[4]x3xi1> -> vector<8x[4]x2x3xf32>
|
||||
func.func @masked_transfer_read_reduce_rank(%mem: memref<?x?x?x?xf32>, %dim: index) -> vector<8x[4]x2x3xf32> {
|
||||
%c0 = arith.constant 0 : index
|
||||
%cst_0 = arith.constant 0.000000e+00 : f32
|
||||
%mask = vector.create_mask %dim, %dim: vector<[4]x3xi1>
|
||||
%res = vector.mask %mask { vector.transfer_read %mem[%c0, %c0, %c0, %c0], %cst_0
|
||||
{in_bounds = [true, true, true, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (0, d1, 0, d3)>}
|
||||
: memref<?x?x?x?xf32>, vector<8x[4]x2x3xf32> } : vector<[4]x3xi1> -> vector<8x[4]x2x3xf32>
|
||||
return %res : vector<8x[4]x2x3xf32>
|
||||
}
|
||||
|
||||
module attributes {transform.with_named_sequence} {
|
||||
transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
|
||||
%f = transform.structured.match ops{["func.func"]} in %module_op
|
||||
: (!transform.any_op) -> !transform.any_op
|
||||
transform.apply_patterns to %f {
|
||||
transform.apply_patterns.vector.transfer_permutation_patterns
|
||||
} : !transform.any_op
|
||||
transform.yield
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user