mirror of
https://github.com/intel/llvm.git
synced 2026-01-24 08:30:34 +08:00
[mlir][ArmSME] Lower extract from 2D scalable create_mask to psel (#96066)
Example:
```mlir
%mask = vector.create_mask %a, %b : vector<[4]x[8]xi1>
%slice = vector.extract %mask[%index]
: vector<[8]xi1> from vector<[4]x[8]xi1>
```
Becomes:
```mlir
%mask_rows = vector.create_mask %a : vector<[4]xi1>
%mask_cols = vector.create_mask %b : vector<[8]xi1>
%slice = arm_sve.psel %mask_cols, %mask_rows[%index]
: vector<[8]xi1>, vector<[4]xi1>
```
Note: Although `psel` is defined in the ArmSVE dialect, it requires SME (or
SVE 2.1), so the Vector-to-ArmSME conversion is currently the most logical
place for this lowering.
This commit is contained in:
@@ -1276,7 +1276,7 @@ def ConvertVectorToArmSME : Pass<"convert-vector-to-arm-sme"> {
|
||||
Pass that converts vector dialect operations into equivalent ArmSME dialect
|
||||
operations.
|
||||
}];
|
||||
let dependentDialects = ["arm_sme::ArmSMEDialect"];
|
||||
let dependentDialects = ["arm_sme::ArmSMEDialect", "arm_sve::ArmSVEDialect"];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@@ -10,5 +10,6 @@ add_mlir_conversion_library(MLIRVectorToArmSME
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRArmSMEDialect
|
||||
MLIRArmSVEDialect
|
||||
MLIRLLVMCommonConversion
|
||||
)
|
||||
|
||||
@@ -10,6 +10,7 @@
|
||||
|
||||
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
|
||||
#include "mlir/Dialect/ArmSME/Utils/Utils.h"
|
||||
#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
@@ -719,16 +720,86 @@ struct FoldTransferWriteOfExtractTileSlice
|
||||
}
|
||||
};
|
||||
|
||||
/// Lower a `vector.extract` from a 2-D scalable `vector.create_mask` to
|
||||
/// `arm_sve.psel`. Note: While psel is under ArmSVE it requires SME (or
|
||||
/// SVE 2.1), so this is currently the most logical place for this lowering.
|
||||
///
|
||||
/// Example:
|
||||
/// ```mlir
|
||||
/// %mask = vector.create_mask %a, %b : vector<[4]x[8]xi1>
|
||||
/// %slice = vector.extract %mask[%index]
|
||||
/// : vector<[8]xi1> from vector<[4]x[8]xi1>
|
||||
/// ```
|
||||
/// Becomes:
|
||||
/// ```
|
||||
/// %mask_rows = vector.create_mask %a : vector<[4]xi1>
|
||||
/// %mask_cols = vector.create_mask %b : vector<[8]xi1>
|
||||
/// %slice = arm_sve.psel %mask_cols, %mask_rows[%index]
|
||||
/// : vector<[8]xi1>, vector<[4]xi1>
|
||||
/// ```
|
||||
struct ExtractFromCreateMaskToPselLowering
|
||||
: public OpRewritePattern<vector::ExtractOp> {
|
||||
using OpRewritePattern<vector::ExtractOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
if (extractOp.getNumIndices() != 1)
|
||||
return rewriter.notifyMatchFailure(extractOp, "not single extract index");
|
||||
|
||||
auto resultType = extractOp.getResult().getType();
|
||||
auto resultVectorType = dyn_cast<VectorType>(resultType);
|
||||
if (!resultVectorType)
|
||||
return rewriter.notifyMatchFailure(extractOp, "result not VectorType");
|
||||
|
||||
auto createMaskOp =
|
||||
extractOp.getVector().getDefiningOp<vector::CreateMaskOp>();
|
||||
if (!createMaskOp)
|
||||
return rewriter.notifyMatchFailure(extractOp, "source not CreateMaskOp");
|
||||
|
||||
auto maskType = createMaskOp.getVectorType();
|
||||
if (maskType.getRank() != 2 || !maskType.allDimsScalable())
|
||||
return rewriter.notifyMatchFailure(createMaskOp, "not 2-D scalable mask");
|
||||
|
||||
auto isSVEPredicateSize = [](int64_t size) {
|
||||
return size > 0 && size <= 16 && llvm::isPowerOf2_32(uint32_t(size));
|
||||
};
|
||||
|
||||
auto rowsBaseSize = maskType.getDimSize(0);
|
||||
auto colsBaseSize = maskType.getDimSize(1);
|
||||
if (!isSVEPredicateSize(rowsBaseSize) || !isSVEPredicateSize(colsBaseSize))
|
||||
return rewriter.notifyMatchFailure(
|
||||
createMaskOp, "mask dimensions not SVE predicate-sized");
|
||||
|
||||
auto loc = extractOp.getLoc();
|
||||
VectorType rowMaskType = VectorType::Builder(maskType).dropDim(1);
|
||||
VectorType colMaskType = VectorType::Builder(maskType).dropDim(0);
|
||||
|
||||
// Create the two 1-D masks at the location of the 2-D create_mask (which is
|
||||
// usually outside a loop). This prevents the need for later hoisting.
|
||||
rewriter.setInsertionPoint(createMaskOp);
|
||||
auto rowMask = rewriter.create<vector::CreateMaskOp>(
|
||||
loc, rowMaskType, createMaskOp.getOperand(0));
|
||||
auto colMask = rewriter.create<vector::CreateMaskOp>(
|
||||
loc, colMaskType, createMaskOp.getOperand(1));
|
||||
|
||||
rewriter.setInsertionPoint(extractOp);
|
||||
auto position =
|
||||
vector::getAsValues(rewriter, loc, extractOp.getMixedPosition());
|
||||
rewriter.replaceOpWithNewOp<arm_sve::PselOp>(extractOp, colMask, rowMask,
|
||||
position[0]);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
/// Registers the Vector-to-ArmSME rewrite patterns named in the `add<...>`
/// list(s) below on `patterns`, constructing each pattern with `&ctx`.
void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns,
|
||||
MLIRContext &ctx) {
|
||||
// NOTE(review): this listing looks like a rendered diff without +/- markers.
// The first `patterns.add` call below is presumably the pre-change text and
// the second one (which additionally lists ExtractFromCreateMaskToPselLowering)
// its replacement -- confirm against the repository before treating both calls
// as live code.
patterns
|
||||
.add<BroadcastOpToArmSMELowering, SplatOpToArmSMELowering,
|
||||
TransferReadToArmSMELowering, TransferWriteToArmSMELowering,
|
||||
TransposeOpToArmSMELowering, VectorLoadToArmSMELowering,
|
||||
VectorStoreToArmSMELowering, VectorOuterProductToArmSMELowering,
|
||||
VectorExtractToArmSMELowering, VectorInsertToArmSMELowering,
|
||||
VectorPrintToArmSMELowering, FoldTransferWriteOfExtractTileSlice>(
|
||||
&ctx);
|
||||
patterns.add<BroadcastOpToArmSMELowering, SplatOpToArmSMELowering,
|
||||
TransferReadToArmSMELowering, TransferWriteToArmSMELowering,
|
||||
TransposeOpToArmSMELowering, VectorLoadToArmSMELowering,
|
||||
VectorStoreToArmSMELowering, VectorOuterProductToArmSMELowering,
|
||||
VectorExtractToArmSMELowering, VectorInsertToArmSMELowering,
|
||||
VectorPrintToArmSMELowering, FoldTransferWriteOfExtractTileSlice,
|
||||
ExtractFromCreateMaskToPselLowering>(&ctx);
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
#include "mlir/Conversion/VectorToArmSME/VectorToArmSME.h"
|
||||
|
||||
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
|
||||
#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
|
||||
|
||||
@@ -192,3 +192,54 @@ func.func @vector_outerproduct_unknown_mask(%lhs : vector<[4]xf32>, %rhs : vecto
|
||||
%0 = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32> } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
|
||||
"prevent.dce"(%0) : (vector<[4]x[4]xf32>) -> ()
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
/// Not SVE predicate-sized.
|
||||
|
||||
// CHECK-LABEL: @negative_vector_extract_to_psel_0
|
||||
// Negative test: the mask's column dimension has base size 32, which is above
// the largest base size (16) that the psel lowering accepts.
func.func @negative_vector_extract_to_psel_0(%a: index, %b: index, %index: index) -> vector<[32]xi1>
|
||||
{
|
||||
// CHECK-NOT: arm_sve.psel
|
||||
%mask = vector.create_mask %a, %b : vector<[4]x[32]xi1>
|
||||
%slice = vector.extract %mask[%index] : vector<[32]xi1> from vector<[4]x[32]xi1>
|
||||
return %slice : vector<[32]xi1>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
/// Source not 2-D scalable mask.
|
||||
|
||||
// CHECK-LABEL: @negative_vector_extract_to_psel_1
|
||||
// Negative test: the leading mask dimension (4) is fixed-size, so the mask is
// not scalable in both dimensions and the lowering must not apply.
func.func @negative_vector_extract_to_psel_1(%a: index, %b: index, %index: index) -> vector<[8]xi1>
|
||||
{
|
||||
// CHECK-NOT: arm_sve.psel
|
||||
%mask = vector.create_mask %a, %b : vector<4x[8]xi1>
|
||||
%slice = vector.extract %mask[%index] : vector<[8]xi1> from vector<4x[8]xi1>
|
||||
return %slice : vector<[8]xi1>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
/// Source not vector.create_mask.
|
||||
|
||||
// CHECK-LABEL: @negative_vector_extract_to_psel_2
|
||||
// Negative test: %mask is a function argument, so there is no defining
// vector.create_mask for the lowering to match.
func.func @negative_vector_extract_to_psel_2(%mask: vector<[4]x[8]xi1>, %index: index) -> vector<[8]xi1>
|
||||
{
|
||||
// CHECK-NOT: arm_sve.psel
|
||||
%slice = vector.extract %mask[%index] : vector<[8]xi1> from vector<[4]x[8]xi1>
|
||||
return %slice : vector<[8]xi1>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
/// Not psel-like extract.
|
||||
|
||||
// CHECK-LABEL: @negative_vector_extract_to_psel_3
|
||||
// Negative test: the extract uses two indices and yields a scalar i1, whereas
// the lowering requires exactly one index.
func.func @negative_vector_extract_to_psel_3(%a: index, %b: index, %index: index) -> i1
|
||||
{
|
||||
// CHECK-NOT: arm_sve.psel
|
||||
%mask = vector.create_mask %a, %b : vector<[4]x[8]xi1>
|
||||
%el = vector.extract %mask[2, %index] : i1 from vector<[4]x[8]xi1>
|
||||
return %el : i1
|
||||
}
|
||||
|
||||
@@ -1124,7 +1124,7 @@ func.func @vector_insert_element_f64(%el: f64, %row: index, %col: index) -> vect
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// vector.extract
|
||||
// vector.extract --> arm_sme.move_tile_slice_to_vector
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// -----
|
||||
@@ -1320,3 +1320,37 @@ func.func @vector_extract_element_f64(%row: index, %col: index) -> f64 {
|
||||
%el = vector.extract %tile[%row, %col] : f64 from vector<[2]x[2]xf64>
|
||||
return %el : f64
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// vector.extract --> arm_sve.psel
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @dynamic_vector_extract_mask_to_psel(
|
||||
// CHECK-SAME: %[[A:.*]]: index, %[[B:.*]]: index, %[[INDEX:.*]]: index)
|
||||
// Positive test with a dynamic row index: the [4]x[8] mask is expected to be
// split into a [4] row mask built from %a and an [8] column mask built from
// %b, with the row mask indexed by %index in the resulting psel.
func.func @dynamic_vector_extract_mask_to_psel(%a: index, %b: index, %index: index) -> vector<[8]xi1>
|
||||
{
|
||||
// CHECK: %[[MASK_ROWS:.*]] = vector.create_mask %[[A]] : vector<[4]xi1>
|
||||
// CHECK: %[[MASK_COLS:.*]] = vector.create_mask %[[B]] : vector<[8]xi1>
|
||||
// CHECK: arm_sve.psel %[[MASK_COLS]], %[[MASK_ROWS]][%[[INDEX]]] : vector<[8]xi1>, vector<[4]xi1>
|
||||
%mask = vector.create_mask %a, %b : vector<[4]x[8]xi1>
|
||||
%slice = vector.extract %mask[%index] : vector<[8]xi1> from vector<[4]x[8]xi1>
|
||||
return %slice : vector<[8]xi1>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @vector_extract_mask_to_psel(
|
||||
// CHECK-SAME: %[[A:.*]]: index,
|
||||
// CHECK-SAME: %[[B:.*]]: index)
|
||||
// Positive test with a constant row index (1): the index is expected to be
// materialised as an `arith.constant` of type index and used as the psel index
// into the [16] row mask.
func.func @vector_extract_mask_to_psel(%a: index, %b: index) -> vector<[2]xi1>
|
||||
{
|
||||
// CHECK: %[[C1:.*]] = arith.constant 1 : index
|
||||
// CHECK: %[[MASK_ROWS:.*]] = vector.create_mask %[[A]] : vector<[16]xi1>
|
||||
// CHECK: %[[MASK_COLS:.*]] = vector.create_mask %[[B]] : vector<[2]xi1>
|
||||
// CHECK: arm_sve.psel %[[MASK_COLS]], %[[MASK_ROWS]][%[[C1]]] : vector<[2]xi1>, vector<[16]xi1>
|
||||
%mask = vector.create_mask %a, %b : vector<[16]x[2]xi1>
|
||||
%slice = vector.extract %mask[1] : vector<[2]xi1> from vector<[16]x[2]xi1>
|
||||
return %slice : vector<[2]xi1>
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user