[MLIR][Vector] Add unroll pattern for vector.constant_mask (#171518)

This PR adds an unrolling pattern for the vector.constant_mask op based on
the targetShape. For each dimension (d), every unrolled vector computes its
local mask size from its offset within the original vector as:
min(max(originalMaskSize[d] - offset[d], 0), targetShape[d]).
This commit is contained in:
Nishant Patel
2025-12-11 13:16:55 -08:00
committed by GitHub
parent 757c5b3bc7
commit 71ee84acc4
4 changed files with 116 additions and 8 deletions

View File

@@ -2534,7 +2534,9 @@ def Vector_TypeCastOp :
}
def Vector_ConstantMaskOp :
Vector_Op<"constant_mask", [Pure]>,
Vector_Op<"constant_mask", [Pure,
DeclareOpInterfaceMethods<VectorUnrollOpInterface>
]>,
Arguments<(ins DenseI64ArrayAttr:$mask_dim_sizes)>,
Results<(outs VectorOfAnyRankOf<[I1]>)> {
let summary = "creates a constant vector mask";

View File

@@ -1094,6 +1094,93 @@ private:
vector::UnrollVectorOptions options;
};
/// Rewrites a `vector.constant_mask` as a set of smaller `vector.constant_mask`
/// ops, one per tile of the target unroll shape, and reassembles the full mask
/// by inserting every tile into an all-false vector of the original type.
///
/// For the tile that starts at `offset`, the mask size along dimension `d` is
///   min(max(maskDimSizes[d] - offset[d], 0), targetShape[d])
/// i.e. the part of the original masked region that lies inside the tile.
///
/// Example:
///   Unrolling
///     %0 = vector.constant_mask [6, 10] : vector<8x16xi1>
///   with a target shape of <4x8> gives:
///
///     %false = arith.constant dense<false> : vector<8x16xi1>
///
///     Tile at [0, 0] covers [0:4, 0:8] and lies entirely inside [6, 10]:
///     %mask00 = vector.constant_mask [4, 8] : vector<4x8xi1>
///     %r0 = vector.insert_strided_slice %mask00, %false [0, 0], [1, 1]
///       : vector<4x8xi1> into vector<8x16xi1>
///
///     Tile at [0, 8] covers [0:4, 8:16] and is clipped in dimension 1:
///     %mask01 = vector.constant_mask [4, 2] : vector<4x8xi1>
///     %r1 = vector.insert_strided_slice %mask01, %r0 [0, 8], [1, 1]
///       : vector<4x8xi1> into vector<8x16xi1>
///
///     Tile at [4, 0] covers [4:8, 0:8] and is clipped in dimension 0:
///     %mask10 = vector.constant_mask [2, 8] : vector<4x8xi1>
///     %r2 = vector.insert_strided_slice %mask10, %r1 [4, 0], [1, 1]
///       : vector<4x8xi1> into vector<8x16xi1>
///
///     Tile at [4, 8] covers [4:8, 8:16] and is clipped in both dimensions:
///     %mask11 = vector.constant_mask [2, 2] : vector<4x8xi1>
///     %result = vector.insert_strided_slice %mask11, %r2 [4, 8], [1, 1]
///       : vector<4x8xi1> into vector<8x16xi1>
struct UnrollConstantMaskPattern
    : public OpRewritePattern<vector::ConstantMaskOp> {
  UnrollConstantMaskPattern(MLIRContext *context,
                            const vector::UnrollVectorOptions &options,
                            PatternBenefit benefit = 1)
      : OpRewritePattern<vector::ConstantMaskOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::ConstantMaskOp constantMaskOp,
                                PatternRewriter &rewriter) const override {
    // Bail out when the options do not request unrolling for this op.
    std::optional<SmallVector<int64_t>> targetShape =
        getTargetShape(options, constantMaskOp);
    if (!targetShape)
      return failure();
    const SmallVector<int64_t> &tileShape = *targetShape;

    Location loc = constantMaskOp.getLoc();
    VectorType fullType = constantMaskOp.getVectorType();
    SmallVector<int64_t> fullShape = *constantMaskOp.getShapeForUnroll();

    // Start from an all-false vector of the original type; each tile is
    // inserted into it in turn.
    Value accumulated = arith::ConstantOp::create(
        rewriter, loc, fullType, rewriter.getZeroAttr(fullType));

    VectorType tileType = VectorType::get(tileShape, rewriter.getI1Type());
    SmallVector<int64_t> unitStrides(tileShape.size(), 1);
    auto maskDimSizes = constantMaskOp.getMaskDimSizes();

    for (const SmallVector<int64_t> &tileOffsets :
         StaticTileOffsetRange(fullShape, tileShape)) {
      // Per dimension: shift the original mask size by the tile offset, then
      // clamp it into [0, tile extent].
      SmallVector<int64_t> tileMaskDims;
      for (size_t dim = 0, rank = maskDimSizes.size(); dim < rank; ++dim) {
        int64_t remaining =
            std::max<int64_t>(maskDimSizes[dim] - tileOffsets[dim], 0);
        tileMaskDims.push_back(std::min<int64_t>(remaining, tileShape[dim]));
      }

      auto tileMask = rewriter.createOrFold<vector::ConstantMaskOp>(
          loc, tileType, tileMaskDims);
      accumulated = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, tileMask, accumulated, tileOffsets, unitStrides);
    }

    rewriter.replaceOp(constantMaskOp, accumulated);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};
/// Checks whether extractShape is a contiguous slice of shape.
/// For extractShape to be contiguous in shape:
/// 1) All but the leading dimension of extractShape and shape must match
@@ -1294,8 +1381,8 @@ void mlir::vector::populateVectorUnrollPatterns(
UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern,
UnrollStorePattern, UnrollBroadcastPattern, UnrollFromElements,
UnrollToElements, UnrollStepPattern, UnrollShapeCastPattern,
UnrollCreateMaskPattern>(patterns.getContext(), options,
benefit);
UnrollCreateMaskPattern, UnrollConstantMaskPattern>(
patterns.getContext(), options, benefit);
}
void mlir::vector::populateVectorToElementsUnrollPatterns(

View File

@@ -552,6 +552,23 @@ func.func @vector_create_mask_constant_dim_sizes() -> vector<16x16xi1> {
// CHECK: %[[S3:.*]] = vector.insert_strided_slice %[[CST_0]], %[[S2]] {offsets = [8, 8], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1>
// CHECK: return %[[S3]] : vector<16x16xi1>
// Exercises unrolling of `vector.constant_mask`: the 16x16 mask [12, 10] is
// expected to be split into 8x8 tiles, each carrying the portion of the
// masked region that falls inside it (see the expected output that follows).
func.func @vector_constant_mask() -> vector<16x16xi1> {
%0 = vector.constant_mask [12, 10] : vector<16x16xi1>
return %0 : vector<16x16xi1>
}
// CHECK-LABEL: func @vector_constant_mask
// CHECK-SAME: () -> vector<16x16xi1>
// CHECK: %[[CST:.*]] = arith.constant dense<false> : vector<16x16xi1>
// CHECK: %[[CST_TRUE:.*]] = arith.constant dense<true> : vector<8x8xi1>
// CHECK: %[[INS00:.*]] = vector.insert_strided_slice %[[CST_TRUE]], %[[CST]] {offsets = [0, 0], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1>
// CHECK: %[[MASK01:.*]] = vector.constant_mask [8, 2] : vector<8x8xi1>
// CHECK: %[[INS01:.*]] = vector.insert_strided_slice %[[MASK01]], %[[INS00]] {offsets = [0, 8], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1>
// CHECK: %[[MASK10:.*]] = vector.constant_mask [4, 8] : vector<8x8xi1>
// CHECK: %[[INS10:.*]] = vector.insert_strided_slice %[[MASK10]], %[[INS01]] {offsets = [8, 0], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1>
// CHECK: %[[MASK11:.*]] = vector.constant_mask [4, 2] : vector<8x8xi1>
// CHECK: %[[INS11:.*]] = vector.insert_strided_slice %[[MASK11]], %[[INS10]] {offsets = [8, 8], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1>
// CHECK: return %[[INS11]] : vector<16x16xi1>
func.func @shape_cast_1D(%v: vector<16xf32>) -> vector<2x2x4xf32> {
%0 = vector.shape_cast %v : vector<16xf32> to vector<2x2x4xf32>

View File

@@ -179,11 +179,13 @@ struct TestVectorUnrollingPatterns
return success(isa<vector::StepOp>(op));
}));
populateVectorUnrollPatterns(
patterns, UnrollVectorOptions()
.setNativeShape(ArrayRef<int64_t>{8, 8})
.setFilterConstraint([](Operation *op) {
return success(isa<vector::CreateMaskOp>(op));
}));
patterns,
UnrollVectorOptions()
.setNativeShape(ArrayRef<int64_t>{8, 8})
.setFilterConstraint([](Operation *op) {
return success(
isa<vector::CreateMaskOp, vector::ConstantMaskOp>(op));
}));
populateVectorUnrollPatterns(
patterns,
UnrollVectorOptions()