diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp index 887d1af76454..f53bb5157eb3 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp @@ -188,7 +188,9 @@ struct MaskOpRewritePattern : OpRewritePattern { private: LogicalResult matchAndRewrite(MaskOp maskOp, PatternRewriter &rewriter) const final { - auto maskableOp = cast(maskOp.getMaskableOp()); + auto maskableOp = cast_or_null(maskOp.getMaskableOp()); + if (!maskableOp) + return failure(); SourceOp sourceOp = dyn_cast(maskableOp.getOperation()); if (!sourceOp) return failure(); @@ -282,6 +284,7 @@ struct LowerVectorMaskPass RewritePatternSet loweringPatterns(context); populateVectorMaskLoweringPatternsForSideEffectingOps(loweringPatterns); + MaskOp::getCanonicalizationPatterns(loweringPatterns, context); if (failed(applyPatternsAndFoldGreedily(op, std::move(loweringPatterns)))) signalPassFailure(); diff --git a/mlir/test/Dialect/Vector/lower-vector-mask.mlir b/mlir/test/Dialect/Vector/lower-vector-mask.mlir index 8f8fae095cac..a8a1164e2f76 100644 --- a/mlir/test/Dialect/Vector/lower-vector-mask.mlir +++ b/mlir/test/Dialect/Vector/lower-vector-mask.mlir @@ -77,3 +77,14 @@ func.func @vector_gather(%arg0: tensor<64xf32>, %arg1: tensor<3xf32>) -> tensor< // CHECK: %[[VAL_7:.*]] = vector.gather %[[VAL_0]][%[[VAL_4]]] [%[[VAL_3]]], %[[VAL_6]], %[[VAL_2]] : tensor<64xf32>, vector<4xindex>, vector<4xi1>, vector<4xf32> into vector<4xf32> // CHECK: %[[VAL_8:.*]] = vector.transfer_write %[[VAL_7]], %[[VAL_1]][%[[VAL_4]]], %[[VAL_6]] {in_bounds = [true]} : vector<4xf32>, tensor<3xf32> +// ----- + +// CHECK-LABEL: func @empty_vector_mask_with_return +// CHECK-SAME: %[[IN:.*]]: vector<8xf32> +func.func @empty_vector_mask_with_return(%a : vector<8xf32>, %mask : vector<8xi1>) -> vector<8xf32> { +// CHECK-NOT: vector.mask +// CHECK: return %[[IN]] : vector<8xf32> + %0 = vector.mask %mask { vector.yield %a : vector<8xf32> } : vector<8xi1> -> vector<8xf32> + return %0 : vector<8xf32> +} +