From 71ee84acc4f7c93b9292af90ef5d79dd05687410 Mon Sep 17 00:00:00 2001 From: Nishant Patel Date: Thu, 11 Dec 2025 13:16:55 -0800 Subject: [PATCH] [MLIR][Vector] Add unroll pattern for vector.constant_mask (#171518) This PR adds unrolling for vector.constant_mask op based on the targetShape. Each unrolled vector computes its local mask size in each dimension (d) as: min(max(originalMaskSize[d] - offset[d], 0), unrolledMaskSize[d]). --- .../mlir/Dialect/Vector/IR/VectorOps.td | 4 +- .../Vector/Transforms/VectorUnroll.cpp | 91 ++++++++++++++++++- .../Dialect/Vector/vector-unroll-options.mlir | 17 ++++ .../Dialect/Vector/TestVectorTransforms.cpp | 12 ++- 4 files changed, 116 insertions(+), 8 deletions(-) diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index d10bedef6040..ddb04b6bbe40 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -2534,7 +2534,9 @@ def Vector_TypeCastOp : } def Vector_ConstantMaskOp : - Vector_Op<"constant_mask", [Pure]>, + Vector_Op<"constant_mask", [Pure, + DeclareOpInterfaceMethods + ]>, Arguments<(ins DenseI64ArrayAttr:$mask_dim_sizes)>, Results<(outs VectorOfAnyRankOf<[I1]>)> { let summary = "creates a constant vector mask"; diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp index 462bd8c3dc4a..b62ce8a2ec39 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp @@ -1094,6 +1094,93 @@ private: vector::UnrollVectorOptions options; }; +/// This pattern unrolls `vector.constant_mask` operations into smaller mask +/// operations based on the target unroll shape. Each unrolled slice computes +/// whether its elements should be masked based on the original mask dimensions +/// and the slice's offset position. +/// +/// Example: +/// Given a constant_mask operation: +/// %0 = vector.constant_mask [6, 10] : vector<8x16xi1> +/// +/// and a target unroll shape of <4x8>, the pattern produces: +/// +/// %false = arith.constant dense : vector<8x16xi1> +/// +/// Slice [0,0]: elements [0:4, 0:8] - fully within [6, 10] bounds +/// %mask00 = vector.constant_mask [4, 8] : vector<4x8xi1> +/// %r0 = vector.insert_strided_slice %mask00, %false [0, 0], [1, 1] +/// : vector<4x8xi1> into vector<8x16xi1> +/// +/// Slice [0,8]: elements [0:4, 8:16] - partially within bounds +/// %mask01 = vector.constant_mask [4, 2] : vector<4x8xi1> +/// %r1 = vector.insert_strided_slice %mask01, %r0 [0, 8], [1, 1] +/// : vector<4x8xi1> into vector<8x16xi1> +/// +/// Slice [4,0]: elements [4:8, 0:8] - partially within bounds +/// %mask10 = vector.constant_mask [2, 8] : vector<4x8xi1> +/// %r2 = vector.insert_strided_slice %mask10, %r1 [4, 0], [1, 1] +/// : vector<4x8xi1> into vector<8x16xi1> +/// +/// Slice [4,8]: elements [4:8, 8:16] - partially within bounds +/// %mask11 = vector.constant_mask [2, 2] : vector<4x8xi1> +/// %result = vector.insert_strided_slice %mask11, %r2 [4, 8], [1, 1] +/// : vector<4x8xi1> into vector<8x16xi1> +struct UnrollConstantMaskPattern + : public OpRewritePattern { + UnrollConstantMaskPattern(MLIRContext *context, + const vector::UnrollVectorOptions &options, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), + options(options) {} + + LogicalResult matchAndRewrite(vector::ConstantMaskOp constantMaskOp, + PatternRewriter &rewriter) const override { + std::optional> targetShape = + getTargetShape(options, constantMaskOp); + if (!targetShape) + return failure(); + + VectorType resultType = constantMaskOp.getVectorType(); + SmallVector originalSize = *constantMaskOp.getShapeForUnroll(); + Location loc = constantMaskOp.getLoc(); + + Value result = arith::ConstantOp::create(rewriter, loc, resultType, + rewriter.getZeroAttr(resultType)); + VectorType targetVectorType = + VectorType::get(*targetShape, rewriter.getI1Type()); + SmallVector strides(targetShape->size(), 1); + + // In each dimension (d), each unrolled vector computes its mask size as: + // min(max(originalMaskDim[d] - offset[d], 0), unrolledDimSize[d]). + for (const SmallVector &offsets : + StaticTileOffsetRange(originalSize, *targetShape)) { + SmallVector unrolledMaskDims; + + for (auto [i, originalMaskDim] : + llvm::enumerate(constantMaskOp.getMaskDimSizes())) { + // Calculate how many elements in this dimension should be masked + // for this particular slice + int64_t adjustedMaskSize = + std::max(originalMaskDim - offsets[i], static_cast(0)); + int64_t unrolledMaskDim = + std::min(adjustedMaskSize, static_cast((*targetShape)[i])); + unrolledMaskDims.push_back(unrolledMaskDim); + } + + auto unrolledMask = rewriter.createOrFold( + loc, targetVectorType, unrolledMaskDims); + result = rewriter.createOrFold( + loc, unrolledMask, result, offsets, strides); + } + rewriter.replaceOp(constantMaskOp, result); + return success(); + } + +private: + vector::UnrollVectorOptions options; +}; + /// Checks whether extractShape is a contiguous slice of shape. /// For extractShape to be contiguous in shape: /// 1) All but the leading dimension of extractShape and shape must match @@ -1294,8 +1381,8 @@ void mlir::vector::populateVectorUnrollPatterns( UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern, UnrollStorePattern, UnrollBroadcastPattern, UnrollFromElements, UnrollToElements, UnrollStepPattern, UnrollShapeCastPattern, - UnrollCreateMaskPattern>(patterns.getContext(), options, - benefit); + UnrollCreateMaskPattern, UnrollConstantMaskPattern>( + patterns.getContext(), options, benefit); } void mlir::vector::populateVectorToElementsUnrollPatterns( diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir index 805e66f133c5..c2e7f6a9338b 100644 --- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir +++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir @@ -552,6 +552,23 @@ func.func @vector_create_mask_constant_dim_sizes() -> vector<16x16xi1> { // CHECK: %[[S3:.*]] = vector.insert_strided_slice %[[CST_0]], %[[S2]] {offsets = [8, 8], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1> // CHECK: return %[[S3]] : vector<16x16xi1> +func.func @vector_constant_mask() -> vector<16x16xi1> { + %0 = vector.constant_mask [12, 10] : vector<16x16xi1> + return %0 : vector<16x16xi1> +} + +// CHECK-LABEL: func @vector_constant_mask +// CHECK-SAME: () -> vector<16x16xi1> +// CHECK: %[[CST:.*]] = arith.constant dense : vector<16x16xi1> +// CHECK: %[[CST_TRUE:.*]] = arith.constant dense : vector<8x8xi1> +// CHECK: %[[INS00:.*]] = vector.insert_strided_slice %[[CST_TRUE]], %[[CST]] {offsets = [0, 0], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1> +// CHECK: %[[MASK01:.*]] = vector.constant_mask [8, 2] : vector<8x8xi1> +// CHECK: %[[INS01:.*]] = vector.insert_strided_slice %[[MASK01]], %[[INS00]] {offsets = [0, 8], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1> +// CHECK: %[[MASK10:.*]] = vector.constant_mask [4, 8] : vector<8x8xi1> +// CHECK: %[[INS10:.*]] = vector.insert_strided_slice %[[MASK10]], %[[INS01]] {offsets = [8, 0], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1> +// CHECK: %[[MASK11:.*]] = vector.constant_mask [4, 2] : vector<8x8xi1> +// CHECK: %[[INS11:.*]] = vector.insert_strided_slice %[[MASK11]], %[[INS10]] {offsets = [8, 8], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1> +// CHECK: return %[[INS11]] : vector<16x16xi1> func.func @shape_cast_1D(%v: vector<16xf32>) -> vector<2x2x4xf32> { %0 = vector.shape_cast %v : vector<16xf32> to vector<2x2x4xf32> diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp index f834d0cdd42b..2cbb5ab3067f 100644 --- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp +++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp @@ -179,11 +179,13 @@ struct TestVectorUnrollingPatterns return success(isa(op)); })); populateVectorUnrollPatterns( - patterns, UnrollVectorOptions() - .setNativeShape(ArrayRef{8, 8}) - .setFilterConstraint([](Operation *op) { - return success(isa(op)); - })); + patterns, + UnrollVectorOptions() + .setNativeShape(ArrayRef{8, 8}) + .setFilterConstraint([](Operation *op) { + return success( + isa(op)); + })); populateVectorUnrollPatterns( patterns, UnrollVectorOptions()