[mlir][vector] Handle empty MaskOp in LowerVectorMask, MaskOpRewritePattern (#72031)

This patch adds handling of an empty `MaskOp` to `MaskOpRewritePattern`
and thereby fixes a crash.
It also pulls the `MaskOp` canonicalization patterns into
`LowerVectorMask` so that empty `MaskOp`s are folded away in the Pass.

Fix https://github.com/llvm/llvm-project/issues/71036
This commit is contained in:
Felix Schneider
2023-11-12 08:12:28 +01:00
committed by GitHub
parent fcb160eabc
commit d5a0fb39ae
2 changed files with 15 additions and 1 deletions

View File

@@ -188,7 +188,9 @@ struct MaskOpRewritePattern : OpRewritePattern<MaskOp> {
private:
LogicalResult matchAndRewrite(MaskOp maskOp,
PatternRewriter &rewriter) const final {
auto maskableOp = cast<MaskableOpInterface>(maskOp.getMaskableOp());
auto maskableOp = cast_or_null<MaskableOpInterface>(maskOp.getMaskableOp());
if (!maskableOp)
return failure();
SourceOp sourceOp = dyn_cast<SourceOp>(maskableOp.getOperation());
if (!sourceOp)
return failure();
@@ -282,6 +284,7 @@ struct LowerVectorMaskPass
RewritePatternSet loweringPatterns(context);
populateVectorMaskLoweringPatternsForSideEffectingOps(loweringPatterns);
MaskOp::getCanonicalizationPatterns(loweringPatterns, context);
if (failed(applyPatternsAndFoldGreedily(op, std::move(loweringPatterns))))
signalPassFailure();

View File

@@ -77,3 +77,14 @@ func.func @vector_gather(%arg0: tensor<64xf32>, %arg1: tensor<3xf32>) -> tensor<
// CHECK: %[[VAL_7:.*]] = vector.gather %[[VAL_0]][%[[VAL_4]]] [%[[VAL_3]]], %[[VAL_6]], %[[VAL_2]] : tensor<64xf32>, vector<4xindex>, vector<4xi1>, vector<4xf32> into vector<4xf32>
// CHECK: %[[VAL_8:.*]] = vector.transfer_write %[[VAL_7]], %[[VAL_1]][%[[VAL_4]]], %[[VAL_6]] {in_bounds = [true]} : vector<4xf32>, tensor<3xf32>
// -----
// CHECK-LABEL: func @empty_vector_mask_with_return
// CHECK-SAME: %[[IN:.*]]: vector<8xf32>
func.func @empty_vector_mask_with_return(%a : vector<8xf32>, %mask : vector<8xi1>) -> vector<8xf32> {
// CHECK-NOT: vector.mask
// CHECK: return %[[IN]] : vector<8xf32>
%0 = vector.mask %mask { vector.yield %a : vector<8xf32> } : vector<8xi1> -> vector<8xf32>
return %0 : vector<8xf32>
}