mirror of
https://github.com/intel/llvm.git
synced 2026-01-13 11:02:04 +08:00
[MLIR][Vector] Add unroll pattern for vector.constant_mask (#171518)
This PR adds unrolling for vector.constant_mask op based on the targetShape. Each unrolled vector computes its local mask size in each dimension (d) as: min(max(originalMaskSize[d] - offset[d], 0), unrolledMaskSize[d]).
This commit is contained in:
@@ -2534,7 +2534,9 @@ def Vector_TypeCastOp :
|
||||
}
|
||||
|
||||
def Vector_ConstantMaskOp :
|
||||
Vector_Op<"constant_mask", [Pure]>,
|
||||
Vector_Op<"constant_mask", [Pure,
|
||||
DeclareOpInterfaceMethods<VectorUnrollOpInterface>
|
||||
]>,
|
||||
Arguments<(ins DenseI64ArrayAttr:$mask_dim_sizes)>,
|
||||
Results<(outs VectorOfAnyRankOf<[I1]>)> {
|
||||
let summary = "creates a constant vector mask";
|
||||
|
||||
@@ -1094,6 +1094,93 @@ private:
|
||||
vector::UnrollVectorOptions options;
|
||||
};
|
||||
|
||||
/// This pattern unrolls `vector.constant_mask` operations into smaller mask
|
||||
/// operations based on the target unroll shape. Each unrolled slice gets its
/// own mask sizes, computed per dimension (d) from the original mask sizes and
/// the slice's offset as: min(max(maskSize[d] - offset[d], 0), sliceSize[d]).
|
||||
///
|
||||
/// Example:
|
||||
/// Given a constant_mask operation:
|
||||
/// %0 = vector.constant_mask [6, 10] : vector<8x16xi1>
|
||||
///
|
||||
/// and a target unroll shape of <4x8>, the pattern produces:
|
||||
///
|
||||
/// %false = arith.constant dense<false> : vector<8x16xi1>
|
||||
///
|
||||
/// Slice [0,0]: elements [0:4, 0:8] - fully within [6, 10] bounds
|
||||
/// %mask00 = vector.constant_mask [4, 8] : vector<4x8xi1>
|
||||
/// %r0 = vector.insert_strided_slice %mask00, %false [0, 0], [1, 1]
|
||||
/// : vector<4x8xi1> into vector<8x16xi1>
|
||||
///
|
||||
/// Slice [0,8]: elements [0:4, 8:16] - partially within bounds
|
||||
/// %mask01 = vector.constant_mask [4, 2] : vector<4x8xi1>
|
||||
/// %r1 = vector.insert_strided_slice %mask01, %r0 [0, 8], [1, 1]
|
||||
/// : vector<4x8xi1> into vector<8x16xi1>
|
||||
///
|
||||
/// Slice [4,0]: elements [4:8, 0:8] - partially within bounds
|
||||
/// %mask10 = vector.constant_mask [2, 8] : vector<4x8xi1>
|
||||
/// %r2 = vector.insert_strided_slice %mask10, %r1 [4, 0], [1, 1]
|
||||
/// : vector<4x8xi1> into vector<8x16xi1>
|
||||
///
|
||||
/// Slice [4,8]: elements [4:8, 8:16] - partially within bounds
|
||||
/// %mask11 = vector.constant_mask [2, 2] : vector<4x8xi1>
|
||||
/// %result = vector.insert_strided_slice %mask11, %r2 [4, 8], [1, 1]
|
||||
/// : vector<4x8xi1> into vector<8x16xi1>
|
||||
struct UnrollConstantMaskPattern
|
||||
: public OpRewritePattern<vector::ConstantMaskOp> {
|
||||
UnrollConstantMaskPattern(MLIRContext *context,
|
||||
const vector::UnrollVectorOptions &options,
|
||||
PatternBenefit benefit = 1)
|
||||
: OpRewritePattern<vector::ConstantMaskOp>(context, benefit),
|
||||
options(options) {}
|
||||
|
||||
LogicalResult matchAndRewrite(vector::ConstantMaskOp constantMaskOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
std::optional<SmallVector<int64_t>> targetShape =
|
||||
getTargetShape(options, constantMaskOp);
|
||||
if (!targetShape)
|
||||
return failure();
|
||||
|
||||
VectorType resultType = constantMaskOp.getVectorType();
|
||||
SmallVector<int64_t> originalSize = *constantMaskOp.getShapeForUnroll();
|
||||
Location loc = constantMaskOp.getLoc();
|
||||
|
||||
Value result = arith::ConstantOp::create(rewriter, loc, resultType,
|
||||
rewriter.getZeroAttr(resultType));
|
||||
VectorType targetVectorType =
|
||||
VectorType::get(*targetShape, rewriter.getI1Type());
|
||||
SmallVector<int64_t> strides(targetShape->size(), 1);
|
||||
|
||||
// In each dimension (d), each unrolled vector computes its mask size as:
|
||||
// min(max(originalMaskDim[d] - offset[d], 0), unrolledDimSize[d]).
|
||||
for (const SmallVector<int64_t> &offsets :
|
||||
StaticTileOffsetRange(originalSize, *targetShape)) {
|
||||
SmallVector<int64_t> unrolledMaskDims;
|
||||
|
||||
for (auto [i, originalMaskDim] :
|
||||
llvm::enumerate(constantMaskOp.getMaskDimSizes())) {
|
||||
// Calculate how many elements in this dimension should be masked
|
||||
// for this particular slice
|
||||
int64_t adjustedMaskSize =
|
||||
std::max(originalMaskDim - offsets[i], static_cast<int64_t>(0));
|
||||
int64_t unrolledMaskDim =
|
||||
std::min(adjustedMaskSize, static_cast<int64_t>((*targetShape)[i]));
|
||||
unrolledMaskDims.push_back(unrolledMaskDim);
|
||||
}
|
||||
|
||||
auto unrolledMask = rewriter.createOrFold<vector::ConstantMaskOp>(
|
||||
loc, targetVectorType, unrolledMaskDims);
|
||||
result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
|
||||
loc, unrolledMask, result, offsets, strides);
|
||||
}
|
||||
rewriter.replaceOp(constantMaskOp, result);
|
||||
return success();
|
||||
}
|
||||
|
||||
private:
|
||||
vector::UnrollVectorOptions options;
|
||||
};
|
||||
|
||||
/// Checks whether extractShape is a contiguous slice of shape.
|
||||
/// For extractShape to be contiguous in shape:
|
||||
/// 1) All but the leading dimension of extractShape and shape must match
|
||||
@@ -1294,8 +1381,8 @@ void mlir::vector::populateVectorUnrollPatterns(
|
||||
UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern,
|
||||
UnrollStorePattern, UnrollBroadcastPattern, UnrollFromElements,
|
||||
UnrollToElements, UnrollStepPattern, UnrollShapeCastPattern,
|
||||
UnrollCreateMaskPattern>(patterns.getContext(), options,
|
||||
benefit);
|
||||
UnrollCreateMaskPattern, UnrollConstantMaskPattern>(
|
||||
patterns.getContext(), options, benefit);
|
||||
}
|
||||
|
||||
void mlir::vector::populateVectorToElementsUnrollPatterns(
|
||||
|
||||
@@ -552,6 +552,23 @@ func.func @vector_create_mask_constant_dim_sizes() -> vector<16x16xi1> {
|
||||
// CHECK: %[[S3:.*]] = vector.insert_strided_slice %[[CST_0]], %[[S2]] {offsets = [8, 8], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1>
|
||||
// CHECK: return %[[S3]] : vector<16x16xi1>
|
||||
|
||||
// Input for the constant_mask unrolling test: a 16x16 mask built with mask
// dim sizes [12, 10]. The CHECK lines that follow expect it to be rebuilt
// from four 8x8 tiles inserted at offsets [0, 0], [0, 8], [8, 0] and [8, 8].
func.func @vector_constant_mask() -> vector<16x16xi1> {
|
||||
%0 = vector.constant_mask [12, 10] : vector<16x16xi1>
|
||||
return %0 : vector<16x16xi1>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @vector_constant_mask
|
||||
// CHECK-SAME: () -> vector<16x16xi1>
|
||||
// CHECK: %[[CST:.*]] = arith.constant dense<false> : vector<16x16xi1>
|
||||
// CHECK: %[[CST_TRUE:.*]] = arith.constant dense<true> : vector<8x8xi1>
|
||||
// CHECK: %[[INS00:.*]] = vector.insert_strided_slice %[[CST_TRUE]], %[[CST]] {offsets = [0, 0], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1>
|
||||
// CHECK: %[[MASK01:.*]] = vector.constant_mask [8, 2] : vector<8x8xi1>
|
||||
// CHECK: %[[INS01:.*]] = vector.insert_strided_slice %[[MASK01]], %[[INS00]] {offsets = [0, 8], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1>
|
||||
// CHECK: %[[MASK10:.*]] = vector.constant_mask [4, 8] : vector<8x8xi1>
|
||||
// CHECK: %[[INS10:.*]] = vector.insert_strided_slice %[[MASK10]], %[[INS01]] {offsets = [8, 0], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1>
|
||||
// CHECK: %[[MASK11:.*]] = vector.constant_mask [4, 2] : vector<8x8xi1>
|
||||
// CHECK: %[[INS11:.*]] = vector.insert_strided_slice %[[MASK11]], %[[INS10]] {offsets = [8, 8], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1>
|
||||
// CHECK: return %[[INS11]] : vector<16x16xi1>
|
||||
|
||||
func.func @shape_cast_1D(%v: vector<16xf32>) -> vector<2x2x4xf32> {
|
||||
%0 = vector.shape_cast %v : vector<16xf32> to vector<2x2x4xf32>
|
||||
|
||||
@@ -179,11 +179,13 @@ struct TestVectorUnrollingPatterns
|
||||
return success(isa<vector::StepOp>(op));
|
||||
}));
|
||||
populateVectorUnrollPatterns(
|
||||
patterns, UnrollVectorOptions()
|
||||
.setNativeShape(ArrayRef<int64_t>{8, 8})
|
||||
.setFilterConstraint([](Operation *op) {
|
||||
return success(isa<vector::CreateMaskOp>(op));
|
||||
}));
|
||||
patterns,
|
||||
UnrollVectorOptions()
|
||||
.setNativeShape(ArrayRef<int64_t>{8, 8})
|
||||
.setFilterConstraint([](Operation *op) {
|
||||
return success(
|
||||
isa<vector::CreateMaskOp, vector::ConstantMaskOp>(op));
|
||||
}));
|
||||
populateVectorUnrollPatterns(
|
||||
patterns,
|
||||
UnrollVectorOptions()
|
||||
|
||||
Reference in New Issue
Block a user