mirror of
https://github.com/intel/llvm.git
synced 2026-01-24 08:30:34 +08:00
[mlir][vector] Use DenseI64ArrayAttr for shuffle masks (#101163)
Follow on from #100997. This again removes from boilerplate conversions to/from IntegerAttr and int64_t (otherwise, this is a NFC).
This commit is contained in:
@@ -421,7 +421,7 @@ def Vector_ShuffleOp :
|
||||
TCresVTEtIsSameAsOpBase<0, 1>>,
|
||||
InferTypeOpAdaptor]>,
|
||||
Arguments<(ins AnyFixedVector:$v1, AnyFixedVector:$v2,
|
||||
I64ArrayAttr:$mask)>,
|
||||
DenseI64ArrayAttr:$mask)>,
|
||||
Results<(outs AnyVector:$vector)> {
|
||||
let summary = "shuffle operation";
|
||||
let description = [{
|
||||
@@ -459,11 +459,7 @@ def Vector_ShuffleOp :
|
||||
: vector<f32>, vector<f32> ; yields vector<2xf32>
|
||||
```
|
||||
}];
|
||||
let builders = [
|
||||
OpBuilder<(ins "Value":$v1, "Value":$v2, "ArrayRef<int64_t>")>
|
||||
];
|
||||
let hasFolder = 1;
|
||||
let hasCanonicalizer = 1;
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
VectorType getV1VectorType() {
|
||||
return ::llvm::cast<VectorType>(getV1().getType());
|
||||
@@ -475,7 +471,10 @@ def Vector_ShuffleOp :
|
||||
return ::llvm::cast<VectorType>(getVector().getType());
|
||||
}
|
||||
}];
|
||||
|
||||
let assemblyFormat = "operands $mask attr-dict `:` type(operands)";
|
||||
|
||||
let hasFolder = 1;
|
||||
let hasVerifier = 1;
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
@@ -994,7 +994,7 @@ public:
|
||||
auto v2Type = shuffleOp.getV2VectorType();
|
||||
auto vectorType = shuffleOp.getResultVectorType();
|
||||
Type llvmType = typeConverter->convertType(vectorType);
|
||||
auto maskArrayAttr = shuffleOp.getMask();
|
||||
ArrayRef<int64_t> mask = shuffleOp.getMask();
|
||||
|
||||
// Bail if result type cannot be lowered.
|
||||
if (!llvmType)
|
||||
@@ -1015,7 +1015,7 @@ public:
|
||||
if (rank <= 1 && v1Type == v2Type) {
|
||||
Value llvmShuffleOp = rewriter.create<LLVM::ShuffleVectorOp>(
|
||||
loc, adaptor.getV1(), adaptor.getV2(),
|
||||
LLVM::convertArrayToIndices<int32_t>(maskArrayAttr));
|
||||
llvm::to_vector_of<int32_t>(mask));
|
||||
rewriter.replaceOp(shuffleOp, llvmShuffleOp);
|
||||
return success();
|
||||
}
|
||||
@@ -1029,8 +1029,7 @@ public:
|
||||
eltType = cast<VectorType>(llvmType).getElementType();
|
||||
Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType);
|
||||
int64_t insPos = 0;
|
||||
for (const auto &en : llvm::enumerate(maskArrayAttr)) {
|
||||
int64_t extPos = cast<IntegerAttr>(en.value()).getInt();
|
||||
for (int64_t extPos : mask) {
|
||||
Value value = adaptor.getV1();
|
||||
if (extPos >= v1Dim) {
|
||||
extPos -= v1Dim;
|
||||
|
||||
@@ -527,10 +527,7 @@ struct VectorShuffleOpConvert final
|
||||
return rewriter.notifyMatchFailure(shuffleOp,
|
||||
"unsupported result vector type");
|
||||
|
||||
SmallVector<int32_t, 4> mask = llvm::map_to_vector<4>(
|
||||
shuffleOp.getMask(), [](Attribute attr) -> int32_t {
|
||||
return cast<IntegerAttr>(attr).getValue().getZExtValue();
|
||||
});
|
||||
auto mask = llvm::to_vector_of<int32_t>(shuffleOp.getMask());
|
||||
|
||||
VectorType oldV1Type = shuffleOp.getV1VectorType();
|
||||
VectorType oldV2Type = shuffleOp.getV2VectorType();
|
||||
|
||||
@@ -2464,11 +2464,6 @@ void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
|
||||
// ShuffleOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void ShuffleOp::build(OpBuilder &builder, OperationState &result, Value v1,
|
||||
Value v2, ArrayRef<int64_t> mask) {
|
||||
build(builder, result, v1, v2, getVectorSubscriptAttr(builder, mask));
|
||||
}
|
||||
|
||||
LogicalResult ShuffleOp::verify() {
|
||||
VectorType resultType = getResultVectorType();
|
||||
VectorType v1Type = getV1VectorType();
|
||||
@@ -2491,8 +2486,8 @@ LogicalResult ShuffleOp::verify() {
|
||||
return emitOpError("dimension mismatch");
|
||||
}
|
||||
// Verify mask length.
|
||||
auto maskAttr = getMask().getValue();
|
||||
int64_t maskLength = maskAttr.size();
|
||||
ArrayRef<int64_t> mask = getMask();
|
||||
int64_t maskLength = mask.size();
|
||||
if (maskLength <= 0)
|
||||
return emitOpError("invalid mask length");
|
||||
if (maskLength != resultType.getDimSize(0))
|
||||
@@ -2500,10 +2495,9 @@ LogicalResult ShuffleOp::verify() {
|
||||
// Verify all indices.
|
||||
int64_t indexSize = (v1Type.getRank() == 0 ? 1 : v1Type.getDimSize(0)) +
|
||||
(v2Type.getRank() == 0 ? 1 : v2Type.getDimSize(0));
|
||||
for (const auto &en : llvm::enumerate(maskAttr)) {
|
||||
auto attr = llvm::dyn_cast<IntegerAttr>(en.value());
|
||||
if (!attr || attr.getInt() < 0 || attr.getInt() >= indexSize)
|
||||
return emitOpError("mask index #") << (en.index() + 1) << " out of range";
|
||||
for (auto [idx, maskPos] : llvm::enumerate(mask)) {
|
||||
if (maskPos < 0 || maskPos >= indexSize)
|
||||
return emitOpError("mask index #") << (idx + 1) << " out of range";
|
||||
}
|
||||
return success();
|
||||
}
|
||||
@@ -2527,13 +2521,12 @@ ShuffleOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
|
||||
return success();
|
||||
}
|
||||
|
||||
static bool isStepIndexArray(ArrayAttr idxArr, uint64_t begin, size_t width) {
|
||||
uint64_t expected = begin;
|
||||
return idxArr.size() == width &&
|
||||
llvm::all_of(idxArr.getAsValueRange<IntegerAttr>(),
|
||||
[&expected](auto attr) {
|
||||
return attr.getZExtValue() == expected++;
|
||||
});
|
||||
template <typename T>
|
||||
static bool isStepIndexArray(ArrayRef<T> idxArr, uint64_t begin, size_t width) {
|
||||
T expected = begin;
|
||||
return idxArr.size() == width && llvm::all_of(idxArr, [&expected](T value) {
|
||||
return value == expected++;
|
||||
});
|
||||
}
|
||||
|
||||
OpFoldResult vector::ShuffleOp::fold(FoldAdaptor adaptor) {
|
||||
@@ -2568,8 +2561,7 @@ OpFoldResult vector::ShuffleOp::fold(FoldAdaptor adaptor) {
|
||||
SmallVector<Attribute> results;
|
||||
auto lhsElements = llvm::cast<DenseElementsAttr>(lhs).getValues<Attribute>();
|
||||
auto rhsElements = llvm::cast<DenseElementsAttr>(rhs).getValues<Attribute>();
|
||||
for (const auto &index : this->getMask().getAsValueRange<IntegerAttr>()) {
|
||||
int64_t i = index.getZExtValue();
|
||||
for (int64_t i : this->getMask()) {
|
||||
if (i >= lhsSize) {
|
||||
results.push_back(rhsElements[i - lhsSize]);
|
||||
} else {
|
||||
@@ -2590,13 +2582,13 @@ struct Canonicalize0DShuffleOp : public OpRewritePattern<ShuffleOp> {
|
||||
LogicalResult matchAndRewrite(ShuffleOp shuffleOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
VectorType v1VectorType = shuffleOp.getV1VectorType();
|
||||
ArrayAttr mask = shuffleOp.getMask();
|
||||
ArrayRef<int64_t> mask = shuffleOp.getMask();
|
||||
if (v1VectorType.getRank() > 0)
|
||||
return failure();
|
||||
if (mask.size() != 1)
|
||||
return failure();
|
||||
VectorType resType = VectorType::Builder(v1VectorType).setShape({1});
|
||||
if (llvm::cast<IntegerAttr>(mask[0]).getInt() == 0)
|
||||
if (mask[0] == 0)
|
||||
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(shuffleOp, resType,
|
||||
shuffleOp.getV1());
|
||||
else
|
||||
@@ -2651,11 +2643,11 @@ public:
|
||||
op, "ShuffleOp types don't match an interleave");
|
||||
}
|
||||
|
||||
ArrayAttr shuffleMask = op.getMask();
|
||||
ArrayRef<int64_t> shuffleMask = op.getMask();
|
||||
int64_t resultVectorSize = resultType.getNumElements();
|
||||
for (int i = 0, e = resultVectorSize / 2; i < e; ++i) {
|
||||
int64_t maskValueA = cast<IntegerAttr>(shuffleMask[i * 2]).getInt();
|
||||
int64_t maskValueB = cast<IntegerAttr>(shuffleMask[(i * 2) + 1]).getInt();
|
||||
int64_t maskValueA = shuffleMask[i * 2];
|
||||
int64_t maskValueB = shuffleMask[(i * 2) + 1];
|
||||
if (maskValueA != i || maskValueB != (resultVectorSize / 2) + i)
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"ShuffleOp mask not interleaving");
|
||||
|
||||
@@ -225,8 +225,7 @@ public:
|
||||
off += stride)
|
||||
offsets.push_back(off);
|
||||
rewriter.replaceOpWithNewOp<ShuffleOp>(op, dstType, op.getVector(),
|
||||
op.getVector(),
|
||||
rewriter.getI64ArrayAttr(offsets));
|
||||
op.getVector(), offsets);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
@@ -232,8 +232,7 @@ struct LinearizeVectorExtractStridedSlice final
|
||||
}
|
||||
// Perform a shuffle to extract the kD vector.
|
||||
rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
|
||||
extractOp, dstType, srcVector, srcVector,
|
||||
rewriter.getI64ArrayAttr(indices));
|
||||
extractOp, dstType, srcVector, srcVector, indices);
|
||||
return success();
|
||||
}
|
||||
|
||||
@@ -298,20 +297,17 @@ struct LinearizeVectorShuffle final
|
||||
// that needs to be shuffled to the destination vector. If shuffleSliceLen >
|
||||
// 1 we need to shuffle the slices (consecutive shuffleSliceLen number of
|
||||
// elements) instead of scalars.
|
||||
ArrayAttr mask = shuffleOp.getMask();
|
||||
ArrayRef<int64_t> mask = shuffleOp.getMask();
|
||||
int64_t totalSizeOfShuffledElmnts = mask.size() * shuffleSliceLen;
|
||||
llvm::SmallVector<int64_t, 2> indices(totalSizeOfShuffledElmnts);
|
||||
for (auto [i, value] :
|
||||
llvm::enumerate(mask.getAsValueRange<IntegerAttr>())) {
|
||||
|
||||
int64_t v = value.getZExtValue();
|
||||
for (auto [i, value] : llvm::enumerate(mask)) {
|
||||
std::iota(indices.begin() + shuffleSliceLen * i,
|
||||
indices.begin() + shuffleSliceLen * (i + 1),
|
||||
shuffleSliceLen * v);
|
||||
shuffleSliceLen * value);
|
||||
}
|
||||
|
||||
rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
|
||||
shuffleOp, dstType, vec1, vec2, rewriter.getI64ArrayAttr(indices));
|
||||
rewriter.replaceOpWithNewOp<vector::ShuffleOp>(shuffleOp, dstType, vec1,
|
||||
vec2, indices);
|
||||
return success();
|
||||
}
|
||||
|
||||
@@ -368,8 +364,7 @@ struct LinearizeVectorExtract final
|
||||
llvm::SmallVector<int64_t, 2> indices(size);
|
||||
std::iota(indices.begin(), indices.end(), linearizedOffset);
|
||||
rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
|
||||
extractOp, dstTy, adaptor.getVector(), adaptor.getVector(),
|
||||
rewriter.getI64ArrayAttr(indices));
|
||||
extractOp, dstTy, adaptor.getVector(), adaptor.getVector(), indices);
|
||||
|
||||
return success();
|
||||
}
|
||||
@@ -452,8 +447,7 @@ struct LinearizeVectorInsert final
|
||||
// [offset+srcNumElements, end)
|
||||
|
||||
rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
|
||||
insertOp, dstTy, adaptor.getDest(), adaptor.getSource(),
|
||||
rewriter.getI64ArrayAttr(indices));
|
||||
insertOp, dstTy, adaptor.getDest(), adaptor.getSource(), indices);
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user