mirror of
https://github.com/intel/llvm.git
synced 2026-01-25 19:44:38 +08:00
[mlir][vector] Use DenseI64ArrayAttr for constant_mask dim sizes (#100997)
This prevents a bunch of boilerplate conversions to/from IntegerAttrs and int64_ts. Other than that this is a NFC.
This commit is contained in:
@@ -2443,7 +2443,7 @@ def Vector_TypeCastOp :
|
||||
|
||||
def Vector_ConstantMaskOp :
|
||||
Vector_Op<"constant_mask", [Pure]>,
|
||||
Arguments<(ins I64ArrayAttr:$mask_dim_sizes)>,
|
||||
Arguments<(ins DenseI64ArrayAttr:$mask_dim_sizes)>,
|
||||
Results<(outs VectorOfAnyRankOf<[I1]>)> {
|
||||
let summary = "creates a constant vector mask";
|
||||
let description = [{
|
||||
|
||||
@@ -88,15 +88,14 @@ static MaskFormat getMaskFormat(Value mask) {
|
||||
// Inspect constant mask index. If the index exceeds the
|
||||
// dimension size, all bits are set. If the index is zero
|
||||
// or less, no bits are set.
|
||||
ArrayAttr masks = m.getMaskDimSizes();
|
||||
ArrayRef<int64_t> masks = m.getMaskDimSizes();
|
||||
auto shape = m.getType().getShape();
|
||||
bool allTrue = true;
|
||||
bool allFalse = true;
|
||||
for (auto [maskIdx, dimSize] : llvm::zip_equal(masks, shape)) {
|
||||
int64_t i = llvm::cast<IntegerAttr>(maskIdx).getInt();
|
||||
if (i < dimSize)
|
||||
if (maskIdx < dimSize)
|
||||
allTrue = false;
|
||||
if (i > 0)
|
||||
if (maskIdx > 0)
|
||||
allFalse = false;
|
||||
}
|
||||
if (allTrue)
|
||||
@@ -3593,8 +3592,7 @@ public:
|
||||
if (extractStridedSliceOp.hasNonUnitStrides())
|
||||
return failure();
|
||||
// Gather constant mask dimension sizes.
|
||||
SmallVector<int64_t, 4> maskDimSizes;
|
||||
populateFromInt64AttrArray(constantMaskOp.getMaskDimSizes(), maskDimSizes);
|
||||
ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
|
||||
// Gather strided slice offsets and sizes.
|
||||
SmallVector<int64_t, 4> sliceOffsets;
|
||||
populateFromInt64AttrArray(extractStridedSliceOp.getOffsets(),
|
||||
@@ -3625,7 +3623,7 @@ public:
|
||||
// region.
|
||||
rewriter.replaceOpWithNewOp<ConstantMaskOp>(
|
||||
extractStridedSliceOp, extractStridedSliceOp.getResult().getType(),
|
||||
vector::getVectorSubscriptAttr(rewriter, sliceMaskDimSizes));
|
||||
sliceMaskDimSizes);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
@@ -5410,21 +5408,19 @@ public:
|
||||
}
|
||||
|
||||
if (constantMaskOp) {
|
||||
auto maskDimSizes = constantMaskOp.getMaskDimSizes().getValue();
|
||||
auto maskDimSizes = constantMaskOp.getMaskDimSizes();
|
||||
auto numMaskOperands = maskDimSizes.size();
|
||||
|
||||
// Check every mask dim size to see whether it can be dropped
|
||||
for (size_t i = numMaskOperands - 1; i >= numMaskOperands - numDimsToDrop;
|
||||
--i) {
|
||||
if (cast<IntegerAttr>(maskDimSizes[i]).getValue() != 1)
|
||||
if (maskDimSizes[i] != 1)
|
||||
return failure();
|
||||
}
|
||||
|
||||
auto newMaskOperands = maskDimSizes.drop_back(numDimsToDrop);
|
||||
ArrayAttr newMaskOperandsAttr = rewriter.getArrayAttr(newMaskOperands);
|
||||
|
||||
rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(shapeOp, shapeOpResTy,
|
||||
newMaskOperandsAttr);
|
||||
newMaskOperands);
|
||||
return success();
|
||||
}
|
||||
|
||||
@@ -5804,12 +5800,10 @@ public:
|
||||
|
||||
// ConstantMaskOp case.
|
||||
auto maskDimSizes = constantMaskOp.getMaskDimSizes();
|
||||
SmallVector<Attribute> newMaskDimSizes(maskDimSizes.getValue());
|
||||
applyPermutationToVector(newMaskDimSizes, permutation);
|
||||
auto newMaskDimSizes = applyPermutation(maskDimSizes, permutation);
|
||||
|
||||
rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(
|
||||
transpOp, transpOp.getResultVectorType(),
|
||||
ArrayAttr::get(transpOp.getContext(), newMaskDimSizes));
|
||||
transpOp, transpOp.getResultVectorType(), newMaskDimSizes);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
@@ -5832,7 +5826,7 @@ LogicalResult ConstantMaskOp::verify() {
|
||||
if (resultType.getRank() == 0) {
|
||||
if (getMaskDimSizes().size() != 1)
|
||||
return emitError("array attr must have length 1 for 0-D vectors");
|
||||
auto dim = llvm::cast<IntegerAttr>(getMaskDimSizes()[0]).getInt();
|
||||
auto dim = getMaskDimSizes()[0];
|
||||
if (dim != 0 && dim != 1)
|
||||
return emitError("mask dim size must be either 0 or 1 for 0-D vectors");
|
||||
return success();
|
||||
@@ -5846,9 +5840,8 @@ LogicalResult ConstantMaskOp::verify() {
|
||||
// result dimension size.
|
||||
auto resultShape = resultType.getShape();
|
||||
auto resultScalableDims = resultType.getScalableDims();
|
||||
SmallVector<int64_t, 4> maskDimSizes;
|
||||
for (const auto [index, intAttr] : llvm::enumerate(getMaskDimSizes())) {
|
||||
int64_t maskDimSize = llvm::cast<IntegerAttr>(intAttr).getInt();
|
||||
ArrayRef<int64_t> maskDimSizes = getMaskDimSizes();
|
||||
for (const auto [index, maskDimSize] : llvm::enumerate(maskDimSizes)) {
|
||||
if (maskDimSize < 0 || maskDimSize > resultShape[index])
|
||||
return emitOpError(
|
||||
"array attr of size out of bounds of vector result dimension size");
|
||||
@@ -5856,7 +5849,6 @@ LogicalResult ConstantMaskOp::verify() {
|
||||
maskDimSize != resultShape[index])
|
||||
return emitOpError(
|
||||
"only supports 'none set' or 'all set' scalable dimensions");
|
||||
maskDimSizes.push_back(maskDimSize);
|
||||
}
|
||||
// Verify that if one mask dim size is zero, they all should be zero (because
|
||||
// the mask region is a conjunction of each mask dimension interval).
|
||||
@@ -5873,11 +5865,10 @@ bool ConstantMaskOp::isAllOnesMask() {
|
||||
// Check the corner case of 0-D vectors first.
|
||||
if (resultType.getRank() == 0) {
|
||||
assert(getMaskDimSizes().size() == 1 && "invalid sizes for zero rank mask");
|
||||
return llvm::cast<IntegerAttr>(getMaskDimSizes()[0]).getInt() == 1;
|
||||
return getMaskDimSizes()[0] == 1;
|
||||
}
|
||||
for (const auto [resultSize, intAttr] :
|
||||
for (const auto [resultSize, maskDimSize] :
|
||||
llvm::zip_equal(resultType.getShape(), getMaskDimSizes())) {
|
||||
int64_t maskDimSize = llvm::cast<IntegerAttr>(intAttr).getInt();
|
||||
if (maskDimSize < resultSize)
|
||||
return false;
|
||||
}
|
||||
@@ -6007,9 +5998,8 @@ public:
|
||||
}
|
||||
|
||||
// Replace 'createMaskOp' with ConstantMaskOp.
|
||||
rewriter.replaceOpWithNewOp<ConstantMaskOp>(
|
||||
createMaskOp, retTy,
|
||||
vector::getVectorSubscriptAttr(rewriter, maskDimSizes));
|
||||
rewriter.replaceOpWithNewOp<ConstantMaskOp>(createMaskOp, retTy,
|
||||
maskDimSizes);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
@@ -111,7 +111,7 @@ public:
|
||||
if (rank == 0) {
|
||||
assert(dimSizes.size() == 1 &&
|
||||
"Expected exactly one dim size for a 0-D vector");
|
||||
bool value = cast<IntegerAttr>(dimSizes[0]).getInt() == 1;
|
||||
bool value = dimSizes.front() == 1;
|
||||
rewriter.replaceOpWithNewOp<arith::ConstantOp>(
|
||||
op, dstType,
|
||||
DenseIntElementsAttr::get(VectorType::get({}, rewriter.getI1Type()),
|
||||
@@ -119,7 +119,7 @@ public:
|
||||
return success();
|
||||
}
|
||||
|
||||
int64_t trueDimSize = cast<IntegerAttr>(dimSizes[0]).getInt();
|
||||
int64_t trueDimSize = dimSizes.front();
|
||||
|
||||
if (rank == 1) {
|
||||
if (trueDimSize == 0 || trueDimSize == dstType.getDimSize(0)) {
|
||||
@@ -147,7 +147,7 @@ public:
|
||||
|
||||
VectorType lowType = VectorType::Builder(dstType).dropDim(0);
|
||||
Value trueVal = rewriter.create<vector::ConstantMaskOp>(
|
||||
loc, lowType, rewriter.getArrayAttr(dimSizes.getValue().drop_front()));
|
||||
loc, lowType, dimSizes.drop_front());
|
||||
Value result = rewriter.create<arith::ConstantOp>(
|
||||
loc, dstType, rewriter.getZeroAttr(dstType));
|
||||
for (int64_t d = 0; d < trueDimSize; d++)
|
||||
|
||||
@@ -550,9 +550,7 @@ struct CastAwayConstantMaskLeadingOneDim
|
||||
return failure();
|
||||
|
||||
int64_t dropDim = oldType.getRank() - newType.getRank();
|
||||
SmallVector<int64_t> dimSizes;
|
||||
for (auto attr : mask.getMaskDimSizes())
|
||||
dimSizes.push_back(llvm::cast<IntegerAttr>(attr).getInt());
|
||||
ArrayRef<int64_t> dimSizes = mask.getMaskDimSizes();
|
||||
|
||||
// If any of the dropped unit dims has a size of `0`, the entire mask is a
|
||||
// zero mask, else the unit dim has no effect on the mask.
|
||||
@@ -563,7 +561,7 @@ struct CastAwayConstantMaskLeadingOneDim
|
||||
newDimSizes.append(dimSizes.begin() + dropDim + 1, dimSizes.end());
|
||||
|
||||
auto newMask = rewriter.create<vector::ConstantMaskOp>(
|
||||
mask.getLoc(), newType, rewriter.getI64ArrayAttr(newDimSizes));
|
||||
mask.getLoc(), newType, newDimSizes);
|
||||
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(mask, oldType, newMask);
|
||||
return success();
|
||||
}
|
||||
|
||||
@@ -83,17 +83,14 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
|
||||
newMask = rewriter.create<vector::CreateMaskOp>(loc, newMaskType,
|
||||
newMaskOperands);
|
||||
} else if (constantMaskOp) {
|
||||
ArrayRef<Attribute> maskDimSizes =
|
||||
constantMaskOp.getMaskDimSizes().getValue();
|
||||
ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
|
||||
size_t numMaskOperands = maskDimSizes.size();
|
||||
auto origIndex =
|
||||
cast<IntegerAttr>(maskDimSizes[numMaskOperands - 1]).getInt();
|
||||
IntegerAttr maskIndexAttr =
|
||||
rewriter.getI64IntegerAttr((origIndex + scale - 1) / scale);
|
||||
SmallVector<Attribute> newMaskDimSizes(maskDimSizes.drop_back());
|
||||
newMaskDimSizes.push_back(maskIndexAttr);
|
||||
newMask = rewriter.create<vector::ConstantMaskOp>(
|
||||
loc, newMaskType, rewriter.getArrayAttr(newMaskDimSizes));
|
||||
int64_t origIndex = maskDimSizes[numMaskOperands - 1];
|
||||
int64_t maskIndex = (origIndex + scale - 1) / scale;
|
||||
SmallVector<int64_t> newMaskDimSizes(maskDimSizes.drop_back());
|
||||
newMaskDimSizes.push_back(maskIndex);
|
||||
newMask = rewriter.create<vector::ConstantMaskOp>(loc, newMaskType,
|
||||
newMaskDimSizes);
|
||||
}
|
||||
|
||||
while (!extractOps.empty()) {
|
||||
|
||||
Reference in New Issue
Block a user