mirror of
https://github.com/intel/llvm.git
synced 2026-01-15 12:25:46 +08:00
[mlir][armsme][vector] Replace splat with broadcast (#148024)
Part of deprecation of vector.splat RFC: https://discourse.llvm.org/t/rfc-mlir-vector-deprecate-then-remove-vector-splat/87143/4
This commit is contained in:
@@ -607,7 +607,8 @@ struct InsertTileSliceConversion
|
||||
rewriter.getIntegerAttr(rewriter.getI1Type(), 1));
|
||||
auto predTy = VectorType::get(tileType.getShape()[0], rewriter.getI1Type(),
|
||||
/*scalableDims=*/{true});
|
||||
auto allActiveMask = vector::SplatOp::create(rewriter, loc, predTy, one);
|
||||
auto allActiveMask =
|
||||
vector::BroadcastOp::create(rewriter, loc, predTy, one);
|
||||
|
||||
// Create 'arm_sme.intr.write.(horiz|vert)' to write vector to tile slice.
|
||||
switch (insertTileSliceOp.getLayout()) {
|
||||
|
||||
@@ -327,7 +327,8 @@ struct TileLoadOpWithMaskAndPadNonZeroConversion
|
||||
|
||||
// Splat pad into 1-D vector matching type of tile slice.
|
||||
VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
|
||||
auto pad1DOp = vector::SplatOp::create(rewriter, loc, tileSliceType, padOp);
|
||||
auto pad1DOp =
|
||||
vector::BroadcastOp::create(rewriter, loc, tileSliceType, padOp);
|
||||
|
||||
auto loadSlice = vector::MaskedLoadOp::create(rewriter, loc, tileSliceType,
|
||||
tileLoadOp.getBase(),
|
||||
|
||||
@@ -255,66 +255,6 @@ struct BroadcastOpToArmSMELowering
|
||||
}
|
||||
};
|
||||
|
||||
/// Conversion pattern for vector.splat.
|
||||
///
|
||||
/// Example:
|
||||
///
|
||||
/// %splat_to_tile = vector.splat %src : i32 to vector<[4]x[4]xi32>
|
||||
///
|
||||
/// is converted to:
|
||||
///
|
||||
/// %broadcast_to_1d = vector.broadcast %src : i32 to vector<[4]xi32>
|
||||
/// %broadcast_to_tile = scf.for %tile_slice_index = %c0 to %num_tile_slices
|
||||
/// step %c1 iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>)
|
||||
/// {
|
||||
/// %tile_update = arm_sme.insert_tile_slice
|
||||
/// %broadcast_to_1d, %iter_tile[%tile_slice_index] :
|
||||
/// vector<[4]xi32> into vector<[4]x[4]xi32>
|
||||
/// scf.yield %tile_update : vector<[4]x[4]xi32>
|
||||
/// }
|
||||
///
|
||||
/// This is identical to vector.broadcast of a scalar.
|
||||
struct SplatOpToArmSMELowering : public OpRewritePattern<vector::SplatOp> {
|
||||
using OpRewritePattern<vector::SplatOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(vector::SplatOp splatOp,
|
||||
PatternRewriter &rewriter) const final {
|
||||
auto tileType = splatOp.getResult().getType();
|
||||
if (!tileType || !arm_sme::isValidSMETileVectorType(tileType))
|
||||
return failure();
|
||||
|
||||
auto loc = splatOp.getLoc();
|
||||
auto srcType = splatOp.getOperand().getType();
|
||||
|
||||
assert(srcType.isIntOrFloat() && "Invalid source type for vector.splat");
|
||||
// Avoid unused-variable warning when building without assertions.
|
||||
(void)srcType;
|
||||
|
||||
// First, broadcast the scalar to a 1-d vector.
|
||||
VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
|
||||
Value broadcastOp1D = vector::BroadcastOp::create(
|
||||
rewriter, loc, tileSliceType, splatOp.getInput());
|
||||
|
||||
auto initTile = arm_sme::GetTileOp::create(rewriter, loc, tileType);
|
||||
|
||||
auto makeLoopBody = [&](OpBuilder &b, Location loc, Value tileSliceIndex,
|
||||
Value currentTile) {
|
||||
auto nextTile = arm_sme::InsertTileSliceOp::create(
|
||||
b, loc, tileType, broadcastOp1D, currentTile, tileSliceIndex);
|
||||
return nextTile.getResult();
|
||||
};
|
||||
|
||||
// Next, create a loop over ZA tile slices and "move" the generated 1-d
|
||||
// vector to each slice.
|
||||
auto forOp =
|
||||
createLoopOverTileSlices(rewriter, loc, initTile, makeLoopBody);
|
||||
|
||||
rewriter.replaceOp(splatOp, forOp.getResult(0));
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/// Conversion pattern for vector.transpose.
|
||||
///
|
||||
/// Stores the input tile to memory and reloads vertically.
|
||||
@@ -791,11 +731,25 @@ struct ExtractFromCreateMaskToPselLowering
|
||||
}
|
||||
};
|
||||
|
||||
// Convert all `vector.splat` to `vector.broadcast`. There is a path from
|
||||
// `vector.broadcast` to ArmSME via another pattern.
|
||||
struct ConvertSplatToBroadcast : public OpRewritePattern<vector::SplatOp> {
|
||||
using OpRewritePattern<vector::SplatOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(vector::SplatOp splatOp,
|
||||
PatternRewriter &rewriter) const final {
|
||||
|
||||
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(splatOp, splatOp.getType(),
|
||||
splatOp.getInput());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns,
|
||||
MLIRContext &ctx) {
|
||||
patterns.add<BroadcastOpToArmSMELowering, SplatOpToArmSMELowering,
|
||||
patterns.add<BroadcastOpToArmSMELowering, ConvertSplatToBroadcast,
|
||||
TransferReadToArmSMELowering, TransferWriteToArmSMELowering,
|
||||
TransposeOpToArmSMELowering, VectorLoadToArmSMELowering,
|
||||
VectorStoreToArmSMELowering, VectorOuterProductToArmSMELowering,
|
||||
|
||||
@@ -87,7 +87,7 @@ func.func @arm_sme_tile_load_hor_with_mask_and_pad_zero(%src : memref<?x?xi32>)
|
||||
// CHECK-NEXT: %[[MASK_INDEX:.*]] = arith.index_cast %[[MASK]] : i32 to index
|
||||
// CHECK-NEXT: %[[MASK_1D:.*]] = vector.create_mask %[[MASK_INDEX]] : vector<[4]xi1>
|
||||
// CHECK-NEXT: %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index
|
||||
// CHECK: %[[PAD_1D:.*]] = vector.splat %[[PAD]] : vector<[4]xi32>
|
||||
// CHECK: %[[PAD_1D:.*]] = vector.broadcast %[[PAD]] : i32 to vector<[4]xi32>
|
||||
// CHECK: %[[LOAD_SLICE:.*]] = vector.maskedload %[[SRC]]{{\[}}%[[OFFSET]], %[[C0]]], %[[MASK_1D]], %[[PAD_1D]] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]xi32> into vector<[4]xi32>
|
||||
// CHECK: %[[TILE_UPDATE:.*]] = arm_sme.insert_tile_slice %[[LOAD_SLICE]], %[[CURRENT_TILE]][%[[TILE_SLICE_INDEX]]] : vector<[4]xi32> into vector<[4]x[4]xi32>
|
||||
// CHECK-NEXT: scf.yield %[[TILE_UPDATE]] : vector<[4]x[4]xi32>
|
||||
|
||||
Reference in New Issue
Block a user