[mlir][vector] Use DenseI64ArrayAttr for constant_mask dim sizes (#100997)

This prevents a bunch of boilerplate conversions to/from IntegerAttrs
and int64_ts. Other than that this is a NFC.
This commit is contained in:
Benjamin Maxwell
2024-07-29 18:08:37 +01:00
committed by GitHub
parent 135a1e90a3
commit 0d9b439408
5 changed files with 30 additions and 45 deletions

View File

@@ -2443,7 +2443,7 @@ def Vector_TypeCastOp :
def Vector_ConstantMaskOp :
Vector_Op<"constant_mask", [Pure]>,
Arguments<(ins I64ArrayAttr:$mask_dim_sizes)>,
Arguments<(ins DenseI64ArrayAttr:$mask_dim_sizes)>,
Results<(outs VectorOfAnyRankOf<[I1]>)> {
let summary = "creates a constant vector mask";
let description = [{

View File

@@ -88,15 +88,14 @@ static MaskFormat getMaskFormat(Value mask) {
// Inspect constant mask index. If the index exceeds the
// dimension size, all bits are set. If the index is zero
// or less, no bits are set.
ArrayAttr masks = m.getMaskDimSizes();
ArrayRef<int64_t> masks = m.getMaskDimSizes();
auto shape = m.getType().getShape();
bool allTrue = true;
bool allFalse = true;
for (auto [maskIdx, dimSize] : llvm::zip_equal(masks, shape)) {
int64_t i = llvm::cast<IntegerAttr>(maskIdx).getInt();
if (i < dimSize)
if (maskIdx < dimSize)
allTrue = false;
if (i > 0)
if (maskIdx > 0)
allFalse = false;
}
if (allTrue)
@@ -3593,8 +3592,7 @@ public:
if (extractStridedSliceOp.hasNonUnitStrides())
return failure();
// Gather constant mask dimension sizes.
SmallVector<int64_t, 4> maskDimSizes;
populateFromInt64AttrArray(constantMaskOp.getMaskDimSizes(), maskDimSizes);
ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
// Gather strided slice offsets and sizes.
SmallVector<int64_t, 4> sliceOffsets;
populateFromInt64AttrArray(extractStridedSliceOp.getOffsets(),
@@ -3625,7 +3623,7 @@ public:
// region.
rewriter.replaceOpWithNewOp<ConstantMaskOp>(
extractStridedSliceOp, extractStridedSliceOp.getResult().getType(),
vector::getVectorSubscriptAttr(rewriter, sliceMaskDimSizes));
sliceMaskDimSizes);
return success();
}
};
@@ -5410,21 +5408,19 @@ public:
}
if (constantMaskOp) {
auto maskDimSizes = constantMaskOp.getMaskDimSizes().getValue();
auto maskDimSizes = constantMaskOp.getMaskDimSizes();
auto numMaskOperands = maskDimSizes.size();
// Check every mask dim size to see whether it can be dropped
for (size_t i = numMaskOperands - 1; i >= numMaskOperands - numDimsToDrop;
--i) {
if (cast<IntegerAttr>(maskDimSizes[i]).getValue() != 1)
if (maskDimSizes[i] != 1)
return failure();
}
auto newMaskOperands = maskDimSizes.drop_back(numDimsToDrop);
ArrayAttr newMaskOperandsAttr = rewriter.getArrayAttr(newMaskOperands);
rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(shapeOp, shapeOpResTy,
newMaskOperandsAttr);
newMaskOperands);
return success();
}
@@ -5804,12 +5800,10 @@ public:
// ConstantMaskOp case.
auto maskDimSizes = constantMaskOp.getMaskDimSizes();
SmallVector<Attribute> newMaskDimSizes(maskDimSizes.getValue());
applyPermutationToVector(newMaskDimSizes, permutation);
auto newMaskDimSizes = applyPermutation(maskDimSizes, permutation);
rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(
transpOp, transpOp.getResultVectorType(),
ArrayAttr::get(transpOp.getContext(), newMaskDimSizes));
transpOp, transpOp.getResultVectorType(), newMaskDimSizes);
return success();
}
};
@@ -5832,7 +5826,7 @@ LogicalResult ConstantMaskOp::verify() {
if (resultType.getRank() == 0) {
if (getMaskDimSizes().size() != 1)
return emitError("array attr must have length 1 for 0-D vectors");
auto dim = llvm::cast<IntegerAttr>(getMaskDimSizes()[0]).getInt();
auto dim = getMaskDimSizes()[0];
if (dim != 0 && dim != 1)
return emitError("mask dim size must be either 0 or 1 for 0-D vectors");
return success();
@@ -5846,9 +5840,8 @@ LogicalResult ConstantMaskOp::verify() {
// result dimension size.
auto resultShape = resultType.getShape();
auto resultScalableDims = resultType.getScalableDims();
SmallVector<int64_t, 4> maskDimSizes;
for (const auto [index, intAttr] : llvm::enumerate(getMaskDimSizes())) {
int64_t maskDimSize = llvm::cast<IntegerAttr>(intAttr).getInt();
ArrayRef<int64_t> maskDimSizes = getMaskDimSizes();
for (const auto [index, maskDimSize] : llvm::enumerate(maskDimSizes)) {
if (maskDimSize < 0 || maskDimSize > resultShape[index])
return emitOpError(
"array attr of size out of bounds of vector result dimension size");
@@ -5856,7 +5849,6 @@ LogicalResult ConstantMaskOp::verify() {
maskDimSize != resultShape[index])
return emitOpError(
"only supports 'none set' or 'all set' scalable dimensions");
maskDimSizes.push_back(maskDimSize);
}
// Verify that if one mask dim size is zero, they all should be zero (because
// the mask region is a conjunction of each mask dimension interval).
@@ -5873,11 +5865,10 @@ bool ConstantMaskOp::isAllOnesMask() {
// Check the corner case of 0-D vectors first.
if (resultType.getRank() == 0) {
assert(getMaskDimSizes().size() == 1 && "invalid sizes for zero rank mask");
return llvm::cast<IntegerAttr>(getMaskDimSizes()[0]).getInt() == 1;
return getMaskDimSizes()[0] == 1;
}
for (const auto [resultSize, intAttr] :
for (const auto [resultSize, maskDimSize] :
llvm::zip_equal(resultType.getShape(), getMaskDimSizes())) {
int64_t maskDimSize = llvm::cast<IntegerAttr>(intAttr).getInt();
if (maskDimSize < resultSize)
return false;
}
@@ -6007,9 +5998,8 @@ public:
}
// Replace 'createMaskOp' with ConstantMaskOp.
rewriter.replaceOpWithNewOp<ConstantMaskOp>(
createMaskOp, retTy,
vector::getVectorSubscriptAttr(rewriter, maskDimSizes));
rewriter.replaceOpWithNewOp<ConstantMaskOp>(createMaskOp, retTy,
maskDimSizes);
return success();
}
};

View File

@@ -111,7 +111,7 @@ public:
if (rank == 0) {
assert(dimSizes.size() == 1 &&
"Expected exactly one dim size for a 0-D vector");
bool value = cast<IntegerAttr>(dimSizes[0]).getInt() == 1;
bool value = dimSizes.front() == 1;
rewriter.replaceOpWithNewOp<arith::ConstantOp>(
op, dstType,
DenseIntElementsAttr::get(VectorType::get({}, rewriter.getI1Type()),
@@ -119,7 +119,7 @@ public:
return success();
}
int64_t trueDimSize = cast<IntegerAttr>(dimSizes[0]).getInt();
int64_t trueDimSize = dimSizes.front();
if (rank == 1) {
if (trueDimSize == 0 || trueDimSize == dstType.getDimSize(0)) {
@@ -147,7 +147,7 @@ public:
VectorType lowType = VectorType::Builder(dstType).dropDim(0);
Value trueVal = rewriter.create<vector::ConstantMaskOp>(
loc, lowType, rewriter.getArrayAttr(dimSizes.getValue().drop_front()));
loc, lowType, dimSizes.drop_front());
Value result = rewriter.create<arith::ConstantOp>(
loc, dstType, rewriter.getZeroAttr(dstType));
for (int64_t d = 0; d < trueDimSize; d++)

View File

@@ -550,9 +550,7 @@ struct CastAwayConstantMaskLeadingOneDim
return failure();
int64_t dropDim = oldType.getRank() - newType.getRank();
SmallVector<int64_t> dimSizes;
for (auto attr : mask.getMaskDimSizes())
dimSizes.push_back(llvm::cast<IntegerAttr>(attr).getInt());
ArrayRef<int64_t> dimSizes = mask.getMaskDimSizes();
// If any of the dropped unit dims has a size of `0`, the entire mask is a
// zero mask, else the unit dim has no effect on the mask.
@@ -563,7 +561,7 @@ struct CastAwayConstantMaskLeadingOneDim
newDimSizes.append(dimSizes.begin() + dropDim + 1, dimSizes.end());
auto newMask = rewriter.create<vector::ConstantMaskOp>(
mask.getLoc(), newType, rewriter.getI64ArrayAttr(newDimSizes));
mask.getLoc(), newType, newDimSizes);
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(mask, oldType, newMask);
return success();
}

View File

@@ -83,17 +83,14 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
newMask = rewriter.create<vector::CreateMaskOp>(loc, newMaskType,
newMaskOperands);
} else if (constantMaskOp) {
ArrayRef<Attribute> maskDimSizes =
constantMaskOp.getMaskDimSizes().getValue();
ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
size_t numMaskOperands = maskDimSizes.size();
auto origIndex =
cast<IntegerAttr>(maskDimSizes[numMaskOperands - 1]).getInt();
IntegerAttr maskIndexAttr =
rewriter.getI64IntegerAttr((origIndex + scale - 1) / scale);
SmallVector<Attribute> newMaskDimSizes(maskDimSizes.drop_back());
newMaskDimSizes.push_back(maskIndexAttr);
newMask = rewriter.create<vector::ConstantMaskOp>(
loc, newMaskType, rewriter.getArrayAttr(newMaskDimSizes));
int64_t origIndex = maskDimSizes[numMaskOperands - 1];
int64_t maskIndex = (origIndex + scale - 1) / scale;
SmallVector<int64_t> newMaskDimSizes(maskDimSizes.drop_back());
newMaskDimSizes.push_back(maskIndex);
newMask = rewriter.create<vector::ConstantMaskOp>(loc, newMaskType,
newMaskDimSizes);
}
while (!extractOps.empty()) {