mirror of
https://github.com/intel/llvm.git
synced 2026-01-26 21:53:12 +08:00
[mlir][vector] Handle empty MaskOp in LowerVectorMask, MaskOpRewritePattern (#72031)
This patch adds handling of an empty `MaskOp` to `MaskOpRewritePattern` and thereby fixes a crash. It also pulls the `MaskOp` canonicalization patterns into `LowerVectorMask` so that empty `MaskOp`s are folded away in the Pass. Fix https://github.com/llvm/llvm-project/issues/71036
This commit is contained in:
@@ -188,7 +188,9 @@ struct MaskOpRewritePattern : OpRewritePattern<MaskOp> {
|
||||
private:
|
||||
LogicalResult matchAndRewrite(MaskOp maskOp,
|
||||
PatternRewriter &rewriter) const final {
|
||||
auto maskableOp = cast<MaskableOpInterface>(maskOp.getMaskableOp());
|
||||
auto maskableOp = cast_or_null<MaskableOpInterface>(maskOp.getMaskableOp());
|
||||
if (!maskableOp)
|
||||
return failure();
|
||||
SourceOp sourceOp = dyn_cast<SourceOp>(maskableOp.getOperation());
|
||||
if (!sourceOp)
|
||||
return failure();
|
||||
@@ -282,6 +284,7 @@ struct LowerVectorMaskPass
|
||||
|
||||
RewritePatternSet loweringPatterns(context);
|
||||
populateVectorMaskLoweringPatternsForSideEffectingOps(loweringPatterns);
|
||||
MaskOp::getCanonicalizationPatterns(loweringPatterns, context);
|
||||
|
||||
if (failed(applyPatternsAndFoldGreedily(op, std::move(loweringPatterns))))
|
||||
signalPassFailure();
|
||||
|
||||
@@ -77,3 +77,14 @@ func.func @vector_gather(%arg0: tensor<64xf32>, %arg1: tensor<3xf32>) -> tensor<
|
||||
// CHECK: %[[VAL_7:.*]] = vector.gather %[[VAL_0]][%[[VAL_4]]] [%[[VAL_3]]], %[[VAL_6]], %[[VAL_2]] : tensor<64xf32>, vector<4xindex>, vector<4xi1>, vector<4xf32> into vector<4xf32>
|
||||
// CHECK: %[[VAL_8:.*]] = vector.transfer_write %[[VAL_7]], %[[VAL_1]][%[[VAL_4]]], %[[VAL_6]] {in_bounds = [true]} : vector<4xf32>, tensor<3xf32>
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @empty_vector_mask_with_return
|
||||
// CHECK-SAME: %[[IN:.*]]: vector<8xf32>
|
||||
func.func @empty_vector_mask_with_return(%a : vector<8xf32>, %mask : vector<8xi1>) -> vector<8xf32> {
|
||||
// CHECK-NOT: vector.mask
|
||||
// CHECK: return %[[IN]] : vector<8xf32>
|
||||
%0 = vector.mask %mask { vector.yield %a : vector<8xf32> } : vector<8xi1> -> vector<8xf32>
|
||||
return %0 : vector<8xf32>
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user