[mlir][armsme][vector] Replace splat with broadcast (#148024)

Part of deprecation of vector.splat
RFC: https://discourse.llvm.org/t/rfc-mlir-vector-deprecate-then-remove-vector-splat/87143/4
This commit is contained in:
James Newling
2025-07-23 13:18:09 -04:00
committed by GitHub
parent 8ef0c50eca
commit e67f3237d6
4 changed files with 20 additions and 64 deletions

View File

@@ -607,7 +607,8 @@ struct InsertTileSliceConversion
rewriter.getIntegerAttr(rewriter.getI1Type(), 1));
auto predTy = VectorType::get(tileType.getShape()[0], rewriter.getI1Type(),
/*scalableDims=*/{true});
auto allActiveMask = vector::SplatOp::create(rewriter, loc, predTy, one);
auto allActiveMask =
vector::BroadcastOp::create(rewriter, loc, predTy, one);
// Create 'arm_sme.intr.write.(horiz|vert)' to write vector to tile slice.
switch (insertTileSliceOp.getLayout()) {

View File

@@ -327,7 +327,8 @@ struct TileLoadOpWithMaskAndPadNonZeroConversion
// Splat pad into 1-D vector matching type of tile slice.
VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
auto pad1DOp = vector::SplatOp::create(rewriter, loc, tileSliceType, padOp);
auto pad1DOp =
vector::BroadcastOp::create(rewriter, loc, tileSliceType, padOp);
auto loadSlice = vector::MaskedLoadOp::create(rewriter, loc, tileSliceType,
tileLoadOp.getBase(),

View File

@@ -255,66 +255,6 @@ struct BroadcastOpToArmSMELowering
}
};
/// Conversion pattern for vector.splat.
///
/// Example:
///
/// %splat_to_tile = vector.splat %src : i32 to vector<[4]x[4]xi32>
///
/// is converted to:
///
/// %broadcast_to_1d = vector.broadcast %src : i32 to vector<[4]xi32>
/// %broadcast_to_tile = scf.for %tile_slice_index = %c0 to %num_tile_slices
/// step %c1 iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>)
/// {
/// %tile_update = arm_sme.insert_tile_slice
/// %broadcast_to_1d, %iter_tile[%tile_slice_index] :
/// vector<[4]xi32> into vector<[4]x[4]xi32>
/// scf.yield %tile_update : vector<[4]x[4]xi32>
/// }
///
/// This is identical to vector.broadcast of a scalar.
struct SplatOpToArmSMELowering : public OpRewritePattern<vector::SplatOp> {
using OpRewritePattern<vector::SplatOp>::OpRewritePattern;
LogicalResult matchAndRewrite(vector::SplatOp splatOp,
PatternRewriter &rewriter) const final {
auto tileType = splatOp.getResult().getType();
if (!tileType || !arm_sme::isValidSMETileVectorType(tileType))
return failure();
auto loc = splatOp.getLoc();
auto srcType = splatOp.getOperand().getType();
assert(srcType.isIntOrFloat() && "Invalid source type for vector.splat");
// Avoid unused-variable warning when building without assertions.
(void)srcType;
// First, broadcast the scalar to a 1-d vector.
VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
Value broadcastOp1D = vector::BroadcastOp::create(
rewriter, loc, tileSliceType, splatOp.getInput());
auto initTile = arm_sme::GetTileOp::create(rewriter, loc, tileType);
auto makeLoopBody = [&](OpBuilder &b, Location loc, Value tileSliceIndex,
Value currentTile) {
auto nextTile = arm_sme::InsertTileSliceOp::create(
b, loc, tileType, broadcastOp1D, currentTile, tileSliceIndex);
return nextTile.getResult();
};
// Next, create a loop over ZA tile slices and "move" the generated 1-d
// vector to each slice.
auto forOp =
createLoopOverTileSlices(rewriter, loc, initTile, makeLoopBody);
rewriter.replaceOp(splatOp, forOp.getResult(0));
return success();
}
};
/// Conversion pattern for vector.transpose.
///
/// Stores the input tile to memory and reloads vertically.
@@ -791,11 +731,25 @@ struct ExtractFromCreateMaskToPselLowering
}
};
// Convert all `vector.splat` to `vector.broadcast`. There is a path from
// `vector.broadcast` to ArmSME via another pattern.
struct ConvertSplatToBroadcast : public OpRewritePattern<vector::SplatOp> {
using OpRewritePattern<vector::SplatOp>::OpRewritePattern;
LogicalResult matchAndRewrite(vector::SplatOp splatOp,
PatternRewriter &rewriter) const final {
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(splatOp, splatOp.getType(),
splatOp.getInput());
return success();
}
};
} // namespace
void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns,
MLIRContext &ctx) {
patterns.add<BroadcastOpToArmSMELowering, SplatOpToArmSMELowering,
patterns.add<BroadcastOpToArmSMELowering, ConvertSplatToBroadcast,
TransferReadToArmSMELowering, TransferWriteToArmSMELowering,
TransposeOpToArmSMELowering, VectorLoadToArmSMELowering,
VectorStoreToArmSMELowering, VectorOuterProductToArmSMELowering,

View File

@@ -87,7 +87,7 @@ func.func @arm_sme_tile_load_hor_with_mask_and_pad_zero(%src : memref<?x?xi32>)
// CHECK-NEXT: %[[MASK_INDEX:.*]] = arith.index_cast %[[MASK]] : i32 to index
// CHECK-NEXT: %[[MASK_1D:.*]] = vector.create_mask %[[MASK_INDEX]] : vector<[4]xi1>
// CHECK-NEXT: %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index
// CHECK: %[[PAD_1D:.*]] = vector.splat %[[PAD]] : vector<[4]xi32>
// CHECK: %[[PAD_1D:.*]] = vector.broadcast %[[PAD]] : i32 to vector<[4]xi32>
// CHECK: %[[LOAD_SLICE:.*]] = vector.maskedload %[[SRC]]{{\[}}%[[OFFSET]], %[[C0]]], %[[MASK_1D]], %[[PAD_1D]] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]xi32> into vector<[4]xi32>
// CHECK: %[[TILE_UPDATE:.*]] = arm_sme.insert_tile_slice %[[LOAD_SLICE]], %[[CURRENT_TILE]][%[[TILE_SLICE_INDEX]]] : vector<[4]xi32> into vector<[4]x[4]xi32>
// CHECK-NEXT: scf.yield %[[TILE_UPDATE]] : vector<[4]x[4]xi32>