[mlir][vector] Use DenseI64ArrayAttr for ExtractOp/InsertOp positions

`DenseI64ArrayAttr` provides a better API than `I64ArrayAttr`. For example, the generated accessors return `ArrayRef<int64_t>` instead of `ArrayAttr`, so callers no longer need to unwrap each element as an `IntegerAttr`.

Differential Revision: https://reviews.llvm.org/D156684
This commit is contained in:
Matthias Springer
2023-07-31 15:21:29 +02:00
parent aba0ef7059
commit 16b75cd2bb
14 changed files with 100 additions and 163 deletions

View File

@@ -573,7 +573,7 @@ def Vector_ExtractOp :
PredOpTrait<"operand and result have same element type",
TCresVTEtIsSameAsOpBase<0, 0>>,
InferTypeOpAdaptorWithIsCompatible]>,
Arguments<(ins AnyVectorOfAnyRank:$vector, I64ArrayAttr:$position)>,
Arguments<(ins AnyVectorOfAnyRank:$vector, DenseI64ArrayAttr:$position)>,
Results<(outs AnyType)> {
let summary = "extract operation";
let description = [{
@@ -589,7 +589,6 @@ def Vector_ExtractOp :
```
}];
let builders = [
OpBuilder<(ins "Value":$source, "ArrayRef<int64_t>":$position)>,
// Convenience builder which assumes the values in `position` are defined by
// ConstantIndexOp.
OpBuilder<(ins "Value":$source, "ValueRange":$position)>
@@ -689,7 +688,7 @@ def Vector_InsertOp :
PredOpTrait<"source operand and result have same element type",
TCresVTEtIsSameAsOpBase<0, 0>>,
AllTypesMatch<["dest", "res"]>]>,
Arguments<(ins AnyType:$source, AnyVectorOfAnyRank:$dest, I64ArrayAttr:$position)>,
Arguments<(ins AnyType:$source, AnyVectorOfAnyRank:$dest, DenseI64ArrayAttr:$position)>,
Results<(outs AnyVectorOfAnyRank:$res)> {
let summary = "insert operation";
let description = [{
@@ -711,8 +710,6 @@ def Vector_InsertOp :
}];
let builders = [
OpBuilder<(ins "Value":$source, "Value":$dest,
"ArrayRef<int64_t>":$position)>,
// Convenience builder which assumes all values are constant indices.
OpBuilder<(ins "Value":$source, "Value":$dest, "ValueRange":$position)>
];

View File

@@ -807,8 +807,7 @@ createNonLdMatrixLoads(RewriterBase &rewriter, vector::TransferReadOp op,
Value el = rewriter.create<vector::LoadOp>(loc, loadedElType,
op.getSource(), newIndices);
result = rewriter.create<vector::InsertOp>(loc, el, result,
rewriter.getI64ArrayAttr(i));
result = rewriter.create<vector::InsertOp>(loc, el, result, i);
}
} else {
if (auto vecType = dyn_cast<VectorType>(loadedElType)) {
@@ -832,7 +831,7 @@ createNonLdMatrixLoads(RewriterBase &rewriter, vector::TransferReadOp op,
Value el = rewriter.create<memref::LoadOp>(op.getLoc(), loadedElType,
op.getSource(), newIndices);
result = rewriter.create<vector::InsertOp>(
op.getLoc(), el, result, rewriter.getI64ArrayAttr({i, innerIdx}));
op.getLoc(), el, result, ArrayRef<int64_t>{i, innerIdx});
}
}
}

View File

@@ -1025,44 +1025,37 @@ public:
auto loc = extractOp->getLoc();
auto resultType = extractOp.getResult().getType();
auto llvmResultType = typeConverter->convertType(resultType);
auto positionArrayAttr = extractOp.getPosition();
ArrayRef<int64_t> positionArray = extractOp.getPosition();
// Bail if result type cannot be lowered.
if (!llvmResultType)
return failure();
// Extract entire vector. Should be handled by folder, but just to be safe.
if (positionArrayAttr.empty()) {
if (positionArray.empty()) {
rewriter.replaceOp(extractOp, adaptor.getVector());
return success();
}
// One-shot extraction of vector from array (only requires extractvalue).
if (isa<VectorType>(resultType)) {
SmallVector<int64_t> indices;
for (auto idx : positionArrayAttr.getAsRange<IntegerAttr>())
indices.push_back(idx.getInt());
Value extracted = rewriter.create<LLVM::ExtractValueOp>(
loc, adaptor.getVector(), indices);
loc, adaptor.getVector(), positionArray);
rewriter.replaceOp(extractOp, extracted);
return success();
}
// Potential extraction of 1-D vector from array.
Value extracted = adaptor.getVector();
auto positionAttrs = positionArrayAttr.getValue();
if (positionAttrs.size() > 1) {
SmallVector<int64_t> nMinusOnePosition;
for (auto idx : positionAttrs.drop_back())
nMinusOnePosition.push_back(cast<IntegerAttr>(idx).getInt());
extracted = rewriter.create<LLVM::ExtractValueOp>(loc, extracted,
nMinusOnePosition);
if (positionArray.size() > 1) {
extracted = rewriter.create<LLVM::ExtractValueOp>(
loc, extracted, positionArray.drop_back());
}
// Remaining extraction of element from 1-D LLVM vector
auto position = cast<IntegerAttr>(positionAttrs.back());
auto i64Type = IntegerType::get(rewriter.getContext(), 64);
auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
auto constant =
rewriter.create<LLVM::ConstantOp>(loc, i64Type, positionArray.back());
extracted =
rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
rewriter.replaceOp(extractOp, extracted);
@@ -1147,7 +1140,7 @@ public:
auto sourceType = insertOp.getSourceType();
auto destVectorType = insertOp.getDestVectorType();
auto llvmResultType = typeConverter->convertType(destVectorType);
auto positionArrayAttr = insertOp.getPosition();
ArrayRef<int64_t> positionArray = insertOp.getPosition();
// Bail if result type cannot be lowered.
if (!llvmResultType)
@@ -1155,7 +1148,7 @@ public:
// Overwrite entire vector with value. Should be handled by folder, but
// just to be safe.
if (positionArrayAttr.empty()) {
if (positionArray.empty()) {
rewriter.replaceOp(insertOp, adaptor.getSource());
return success();
}
@@ -1163,36 +1156,32 @@ public:
// One-shot insertion of a vector into an array (only requires insertvalue).
if (isa<VectorType>(sourceType)) {
Value inserted = rewriter.create<LLVM::InsertValueOp>(
loc, adaptor.getDest(), adaptor.getSource(),
LLVM::convertArrayToIndices(positionArrayAttr));
loc, adaptor.getDest(), adaptor.getSource(), positionArray);
rewriter.replaceOp(insertOp, inserted);
return success();
}
// Potential extraction of 1-D vector from array.
Value extracted = adaptor.getDest();
auto positionAttrs = positionArrayAttr.getValue();
auto position = cast<IntegerAttr>(positionAttrs.back());
auto oneDVectorType = destVectorType;
if (positionAttrs.size() > 1) {
if (positionArray.size() > 1) {
oneDVectorType = reducedVectorTypeBack(destVectorType);
extracted = rewriter.create<LLVM::ExtractValueOp>(
loc, extracted,
LLVM::convertArrayToIndices(positionAttrs.drop_back()));
loc, extracted, positionArray.drop_back());
}
// Insertion of an element into a 1-D LLVM vector.
auto i64Type = IntegerType::get(rewriter.getContext(), 64);
auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
auto constant =
rewriter.create<LLVM::ConstantOp>(loc, i64Type, positionArray.back());
Value inserted = rewriter.create<LLVM::InsertElementOp>(
loc, typeConverter->convertType(oneDVectorType), extracted,
adaptor.getSource(), constant);
// Potential insertion of resulting 1-D vector into array.
if (positionAttrs.size() > 1) {
if (positionArray.size() > 1) {
inserted = rewriter.create<LLVM::InsertValueOp>(
loc, adaptor.getDest(), inserted,
LLVM::convertArrayToIndices(positionAttrs.drop_back()));
loc, adaptor.getDest(), inserted, positionArray.drop_back());
}
rewriter.replaceOp(insertOp, inserted);

View File

@@ -886,10 +886,9 @@ struct UnrollTransferReadConversion
/// vector::InsertOp, return that operation's indices.
void getInsertionIndices(TransferReadOp xferOp,
SmallVector<int64_t, 8> &indices) const {
if (auto insertOp = getInsertOp(xferOp)) {
for (Attribute attr : insertOp.getPosition())
indices.push_back(dyn_cast<IntegerAttr>(attr).getInt());
}
if (auto insertOp = getInsertOp(xferOp))
indices.assign(insertOp.getPosition().begin(),
insertOp.getPosition().end());
}
/// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds
@@ -1013,10 +1012,9 @@ struct UnrollTransferWriteConversion
/// indices.
void getExtractionIndices(TransferWriteOp xferOp,
SmallVector<int64_t, 8> &indices) const {
if (auto extractOp = getExtractOp(xferOp)) {
for (Attribute attr : extractOp.getPosition())
indices.push_back(dyn_cast<IntegerAttr>(attr).getInt());
}
if (auto extractOp = getExtractOp(xferOp))
indices.assign(extractOp.getPosition().begin(),
extractOp.getPosition().end());
}
/// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds

View File

@@ -152,7 +152,7 @@ struct VectorExtractOpConvert final
return success();
}
int32_t id = getFirstIntValue(extractOp.getPosition());
int32_t id = extractOp.getPosition()[0];
rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
extractOp, adaptor.getVector(), id);
return success();
@@ -232,7 +232,7 @@ struct VectorInsertOpConvert final
return success();
}
int32_t id = getFirstIntValue(insertOp.getPosition());
int32_t id = insertOp.getPosition()[0];
rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
insertOp, adaptor.getSource(), adaptor.getDest(), id);
return success();

View File

@@ -385,8 +385,7 @@ struct ElideUnitDimsInMultiDimReduction
} else {
// This means we are reducing all the dimensions, and all reduction
// dimensions are of size 1. So a simple extraction would do.
auto zeroAttr =
rewriter.getI64ArrayAttr(SmallVector<int64_t>(shape.size(), 0));
SmallVector<int64_t> zeroAttr(shape.size(), 0);
if (mask)
mask = rewriter.create<vector::ExtractOp>(loc, rewriter.getI1Type(),
mask, zeroAttr);
@@ -560,12 +559,10 @@ struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> {
result = rewriter.create<ExtractElementOp>(loc, reductionOp.getVector());
} else {
if (mask) {
mask = rewriter.create<ExtractOp>(loc, rewriter.getI1Type(), mask,
rewriter.getI64ArrayAttr(0));
mask = rewriter.create<ExtractOp>(loc, rewriter.getI1Type(), mask, 0);
}
result = rewriter.create<ExtractOp>(loc, reductionOp.getType(),
reductionOp.getVector(),
rewriter.getI64ArrayAttr(0));
reductionOp.getVector(), 0);
}
if (Value acc = reductionOp.getAcc())
@@ -1129,18 +1126,11 @@ OpFoldResult vector::ExtractElementOp::fold(FoldAdaptor adaptor) {
// ExtractOp
//===----------------------------------------------------------------------===//
void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
Value source, ArrayRef<int64_t> position) {
build(builder, result, source, getVectorSubscriptAttr(builder, position));
}
// Convenience builder which assumes the values are constant indices.
void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
Value source, ValueRange position) {
SmallVector<int64_t, 4> positionConstants =
llvm::to_vector<4>(llvm::map_range(position, [](Value pos) {
return getConstantIntValue(pos).value();
}));
SmallVector<int64_t> positionConstants = llvm::to_vector(llvm::map_range(
position, [](Value pos) { return getConstantIntValue(pos).value(); }));
build(builder, result, source, positionConstants);
}
@@ -1175,15 +1165,13 @@ bool ExtractOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
}
LogicalResult vector::ExtractOp::verify() {
auto positionAttr = getPosition().getValue();
if (positionAttr.size() >
static_cast<unsigned>(getSourceVectorType().getRank()))
ArrayRef<int64_t> position = getPosition();
if (position.size() > static_cast<unsigned>(getSourceVectorType().getRank()))
return emitOpError(
"expected position attribute of rank no greater than vector rank");
for (const auto &en : llvm::enumerate(positionAttr)) {
auto attr = llvm::dyn_cast<IntegerAttr>(en.value());
if (!attr || attr.getInt() < 0 ||
attr.getInt() >= getSourceVectorType().getDimSize(en.index()))
for (const auto &en : llvm::enumerate(position)) {
if (en.value() < 0 ||
en.value() >= getSourceVectorType().getDimSize(en.index()))
return emitOpError("expected position attribute #")
<< (en.index() + 1)
<< " to be a non-negative integer smaller than the corresponding "
@@ -1207,18 +1195,18 @@ static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) {
SmallVector<int64_t, 4> globalPosition;
ExtractOp currentOp = extractOp;
auto extrPos = extractVector<int64_t>(currentOp.getPosition());
ArrayRef<int64_t> extrPos = currentOp.getPosition();
globalPosition.append(extrPos.rbegin(), extrPos.rend());
while (ExtractOp nextOp = currentOp.getVector().getDefiningOp<ExtractOp>()) {
currentOp = nextOp;
auto extrPos = extractVector<int64_t>(currentOp.getPosition());
ArrayRef<int64_t> extrPos = currentOp.getPosition();
globalPosition.append(extrPos.rbegin(), extrPos.rend());
}
extractOp.setOperand(currentOp.getVector());
// OpBuilder is only used as a helper to build an I64ArrayAttr.
OpBuilder b(extractOp.getContext());
std::reverse(globalPosition.begin(), globalPosition.end());
extractOp.setPositionAttr(b.getI64ArrayAttr(globalPosition));
extractOp.setPosition(globalPosition);
return success();
}
@@ -1329,7 +1317,8 @@ ExtractFromInsertTransposeChainState::ExtractFromInsertTransposeChainState(
sentinels.reserve(vectorRank - extractedRank);
for (int64_t i = 0, e = vectorRank - extractedRank; i < e; ++i)
sentinels.push_back(-(i + 1));
extractPosition = extractVector<int64_t>(extractOp.getPosition());
extractPosition.assign(extractOp.getPosition().begin(),
extractOp.getPosition().end());
llvm::append_range(extractPosition, sentinels);
}
@@ -1349,9 +1338,8 @@ LogicalResult ExtractFromInsertTransposeChainState::handleTransposeOp() {
LogicalResult
ExtractFromInsertTransposeChainState::handleInsertOpWithMatchingPos(
Value &res) {
auto insertedPos = extractVector<int64_t>(nextInsertOp.getPosition());
if (ArrayRef(insertedPos) !=
llvm::ArrayRef(extractPosition).take_front(extractedRank))
ArrayRef<int64_t> insertedPos = nextInsertOp.getPosition();
if (insertedPos != llvm::ArrayRef(extractPosition).take_front(extractedRank))
return failure();
// Case 2.a. early-exit fold.
res = nextInsertOp.getSource();
@@ -1364,7 +1352,7 @@ ExtractFromInsertTransposeChainState::handleInsertOpWithMatchingPos(
/// This method updates the internal state.
LogicalResult
ExtractFromInsertTransposeChainState::handleInsertOpWithPrefixPos(Value &res) {
auto insertedPos = extractVector<int64_t>(nextInsertOp.getPosition());
ArrayRef<int64_t> insertedPos = nextInsertOp.getPosition();
if (!isContainedWithin(insertedPos, extractPosition))
return failure();
// Set leading dims to zero.
@@ -1390,9 +1378,7 @@ Value ExtractFromInsertTransposeChainState::tryToFoldExtractOpInPlace(
return Value();
// Otherwise, fold by updating the op inplace and return its result.
OpBuilder b(extractOp.getContext());
extractOp->setAttr(
extractOp.getPositionAttrName(),
b.getI64ArrayAttr(ArrayRef(extractPosition).take_front(extractedRank)));
extractOp.setPosition(ArrayRef(extractPosition).take_front(extractedRank));
extractOp.getVectorMutable().assign(source);
return extractOp.getResult();
}
@@ -1422,7 +1408,7 @@ Value ExtractFromInsertTransposeChainState::fold() {
// Case 4: extractPositionRef intersects insertedPosRef on non-sentinel
// values. This is a more difficult case and we bail.
auto insertedPos = extractVector<int64_t>(nextInsertOp.getPosition());
ArrayRef<int64_t> insertedPos = nextInsertOp.getPosition();
if (isContainedWithin(extractPosition, insertedPos) ||
intersectsWhereNonNegative(extractPosition, insertedPos))
return Value();
@@ -1487,7 +1473,7 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) {
// extract position to `0` when extracting from the source operand.
llvm::SetVector<int64_t> broadcastedUnitDims =
broadcastOp.computeBroadcastedUnitDims();
auto extractPos = extractVector<int64_t>(extractOp.getPosition());
SmallVector<int64_t> extractPos(extractOp.getPosition());
for (int64_t i = rankDiff, e = extractPos.size(); i < e; ++i)
if (broadcastedUnitDims.contains(i))
extractPos[i] = 0;
@@ -1498,7 +1484,7 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) {
// OpBuilder is only used as a helper to build an I64ArrayAttr.
OpBuilder b(extractOp.getContext());
extractOp.setOperand(source);
extractOp.setPositionAttr(b.getI64ArrayAttr(extractPos));
extractOp.setPosition(extractPos);
return extractOp.getResult();
}
@@ -1537,7 +1523,7 @@ static Value foldExtractFromShapeCast(ExtractOp extractOp) {
}
// Extract the strides associated with the extract op vector source. Then use
// this to calculate a linearized position for the extract.
auto extractedPos = extractVector<int64_t>(extractOp.getPosition());
SmallVector<int64_t> extractedPos(extractOp.getPosition());
std::reverse(extractedPos.begin(), extractedPos.end());
SmallVector<int64_t, 4> strides;
int64_t stride = 1;
@@ -1563,7 +1549,7 @@ static Value foldExtractFromShapeCast(ExtractOp extractOp) {
SmallVector<int64_t, 4> newPosition = delinearize(position, newStrides);
// OpBuilder is only used as a helper to build an I64ArrayAttr.
OpBuilder b(extractOp.getContext());
extractOp.setPositionAttr(b.getI64ArrayAttr(newPosition));
extractOp.setPosition(newPosition);
extractOp.setOperand(shapeCastOp.getSource());
return extractOp.getResult();
}
@@ -1603,14 +1589,14 @@ static Value foldExtractFromExtractStrided(ExtractOp extractOp) {
if (destinationRank > extractStridedSliceOp.getSourceVectorType().getRank() -
sliceOffsets.size())
return Value();
auto extractedPos = extractVector<int64_t>(extractOp.getPosition());
SmallVector<int64_t> extractedPos(extractOp.getPosition());
assert(extractedPos.size() >= sliceOffsets.size());
for (size_t i = 0, e = sliceOffsets.size(); i < e; i++)
extractedPos[i] = extractedPos[i] + sliceOffsets[i];
extractOp.getVectorMutable().assign(extractStridedSliceOp.getVector());
// OpBuilder is only used as a helper to build an I64ArrayAttr.
OpBuilder b(extractOp.getContext());
extractOp.setPositionAttr(b.getI64ArrayAttr(extractedPos));
extractOp.setPosition(extractedPos);
return extractOp.getResult();
}
@@ -1635,7 +1621,7 @@ static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp) {
if (destinationRank > insertOp.getSourceVectorType().getRank())
return Value();
auto insertOffsets = extractVector<int64_t>(insertOp.getOffsets());
auto extractOffsets = extractVector<int64_t>(extractOp.getPosition());
ArrayRef<int64_t> extractOffsets = extractOp.getPosition();
if (llvm::any_of(insertOp.getStrides(), [](Attribute attr) {
return llvm::cast<IntegerAttr>(attr).getInt() != 1;
@@ -1675,7 +1661,7 @@ static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp) {
extractOp.getVectorMutable().assign(insertOp.getSource());
// OpBuilder is only used as a helper to build an I64ArrayAttr.
OpBuilder b(extractOp.getContext());
extractOp.setPositionAttr(b.getI64ArrayAttr(offsetDiffs));
extractOp.setPosition(offsetDiffs);
return extractOp.getResult();
}
// If the chunk extracted is disjoint from the chunk inserted, keep
@@ -1795,7 +1781,7 @@ public:
// Calculate the linearized position of the continuous chunk of elements to
// extract.
llvm::SmallVector<int64_t> completePositions(vecTy.getRank(), 0);
copy(getI64SubArray(extractOp.getPosition()), completePositions.begin());
copy(extractOp.getPosition(), completePositions.begin());
int64_t elemBeginPosition =
linearize(completePositions, computeStrides(vecTy.getShape()));
auto denseValuesBegin = dense.value_begin<TypedAttr>() + elemBeginPosition;
@@ -2288,14 +2274,6 @@ OpFoldResult vector::InsertElementOp::fold(FoldAdaptor adaptor) {
// InsertOp
//===----------------------------------------------------------------------===//
void InsertOp::build(OpBuilder &builder, OperationState &result, Value source,
Value dest, ArrayRef<int64_t> position) {
result.addOperands({source, dest});
auto positionAttr = getVectorSubscriptAttr(builder, position);
result.addTypes(dest.getType());
result.addAttribute(InsertOp::getPositionAttrName(result.name), positionAttr);
}
// Convenience builder which assumes the values are constant indices.
void InsertOp::build(OpBuilder &builder, OperationState &result, Value source,
Value dest, ValueRange position) {
@@ -2307,25 +2285,24 @@ void InsertOp::build(OpBuilder &builder, OperationState &result, Value source,
}
LogicalResult InsertOp::verify() {
auto positionAttr = getPosition().getValue();
ArrayRef<int64_t> position = getPosition();
auto destVectorType = getDestVectorType();
if (positionAttr.size() > static_cast<unsigned>(destVectorType.getRank()))
if (position.size() > static_cast<unsigned>(destVectorType.getRank()))
return emitOpError(
"expected position attribute of rank no greater than dest vector rank");
auto srcVectorType = llvm::dyn_cast<VectorType>(getSourceType());
if (srcVectorType &&
(static_cast<unsigned>(srcVectorType.getRank()) + positionAttr.size() !=
(static_cast<unsigned>(srcVectorType.getRank()) + position.size() !=
static_cast<unsigned>(destVectorType.getRank())))
return emitOpError("expected position attribute rank + source rank to "
"match dest vector rank");
if (!srcVectorType &&
(positionAttr.size() != static_cast<unsigned>(destVectorType.getRank())))
(position.size() != static_cast<unsigned>(destVectorType.getRank())))
return emitOpError(
"expected position attribute rank to match the dest vector rank");
for (const auto &en : llvm::enumerate(positionAttr)) {
auto attr = llvm::dyn_cast<IntegerAttr>(en.value());
if (!attr || attr.getInt() < 0 ||
attr.getInt() >= destVectorType.getDimSize(en.index()))
for (const auto &en : llvm::enumerate(position)) {
int64_t attr = en.value();
if (attr < 0 || attr >= destVectorType.getDimSize(en.index()))
return emitOpError("expected position attribute #")
<< (en.index() + 1)
<< " to be a non-negative integer smaller than the corresponding "
@@ -2412,7 +2389,7 @@ public:
// Calculate the linearized position of the continuous chunk of elements to
// insert.
llvm::SmallVector<int64_t> completePositions(destTy.getRank(), 0);
copy(getI64SubArray(op.getPosition()), completePositions.begin());
copy(op.getPosition(), completePositions.begin());
int64_t insertBeginPosition =
linearize(completePositions, computeStrides(destTy.getShape()));

View File

@@ -91,10 +91,8 @@ static Value reshapeLoad(Location loc, Value val, VectorType type,
return val;
Type lowType = VectorType::Builder(type).dropDim(0);
// At extraction dimension?
if (index == 0) {
auto posAttr = rewriter.getI64ArrayAttr(pos);
return rewriter.create<vector::ExtractOp>(loc, lowType, val, posAttr);
}
if (index == 0)
return rewriter.create<vector::ExtractOp>(loc, lowType, val, pos);
// Unroll leading dimensions.
VectorType vType = cast<VectorType>(lowType);
Type resType = VectorType::Builder(type).dropDim(index);
@@ -102,11 +100,10 @@ static Value reshapeLoad(Location loc, Value val, VectorType type,
Value result = rewriter.create<arith::ConstantOp>(
loc, resVectorType, rewriter.getZeroAttr(resVectorType));
for (int64_t d = 0, e = resVectorType.getDimSize(0); d < e; d++) {
auto posAttr = rewriter.getI64ArrayAttr(d);
Value ext = rewriter.create<vector::ExtractOp>(loc, vType, val, posAttr);
Value ext = rewriter.create<vector::ExtractOp>(loc, vType, val, d);
Value load = reshapeLoad(loc, ext, vType, index - 1, pos, rewriter);
result = rewriter.create<vector::InsertOp>(loc, resVectorType, load, result,
posAttr);
result =
rewriter.create<vector::InsertOp>(loc, resVectorType, load, result, d);
}
return result;
}
@@ -120,20 +117,17 @@ static Value reshapeStore(Location loc, Value val, Value result,
if (index == -1)
return val;
// At insertion dimension?
if (index == 0) {
auto posAttr = rewriter.getI64ArrayAttr(pos);
return rewriter.create<vector::InsertOp>(loc, type, val, result, posAttr);
}
if (index == 0)
return rewriter.create<vector::InsertOp>(loc, type, val, result, pos);
// Unroll leading dimensions.
Type lowType = VectorType::Builder(type).dropDim(0);
VectorType vType = cast<VectorType>(lowType);
Type insType = VectorType::Builder(vType).dropDim(0);
for (int64_t d = 0, e = type.getDimSize(0); d < e; d++) {
auto posAttr = rewriter.getI64ArrayAttr(d);
Value ext = rewriter.create<vector::ExtractOp>(loc, vType, result, posAttr);
Value ins = rewriter.create<vector::ExtractOp>(loc, insType, val, posAttr);
Value ext = rewriter.create<vector::ExtractOp>(loc, vType, result, d);
Value ins = rewriter.create<vector::ExtractOp>(loc, insType, val, d);
Value sto = reshapeStore(loc, ins, ext, vType, index - 1, pos, rewriter);
result = rewriter.create<vector::InsertOp>(loc, type, sto, result, posAttr);
result = rewriter.create<vector::InsertOp>(loc, type, sto, result, d);
}
return result;
}
@@ -823,10 +817,8 @@ struct ContractOpToElementwise
newRhs = rewriter.create<vector::TransposeOp>(loc, newRhs, rhsTranspose);
SmallVector<int64_t> lhsOffsets(lhsReductionDims.size(), 0);
SmallVector<int64_t> rhsOffsets(rhsReductionDims.size(), 0);
newLhs = rewriter.create<vector::ExtractOp>(
loc, newLhs, rewriter.getI64ArrayAttr(lhsOffsets));
newRhs = rewriter.create<vector::ExtractOp>(
loc, newRhs, rewriter.getI64ArrayAttr(rhsOffsets));
newLhs = rewriter.create<vector::ExtractOp>(loc, newLhs, lhsOffsets);
newRhs = rewriter.create<vector::ExtractOp>(loc, newRhs, rhsOffsets);
std::optional<Value> result =
createContractArithOp(loc, newLhs, newRhs, contractOp.getAcc(),
contractOp.getKind(), rewriter, isInt);
@@ -1167,21 +1159,20 @@ public:
Value result = rewriter.create<arith::ConstantOp>(
loc, resType, rewriter.getZeroAttr(resType));
for (int64_t d = 0, e = resType.getDimSize(0); d < e; ++d) {
auto pos = rewriter.getI64ArrayAttr(d);
Value x = rewriter.create<vector::ExtractOp>(loc, op.getLhs(), pos);
Value x = rewriter.create<vector::ExtractOp>(loc, op.getLhs(), d);
Value a = rewriter.create<vector::BroadcastOp>(loc, rhsType, x);
Value r = nullptr;
if (acc)
r = rewriter.create<vector::ExtractOp>(loc, acc, pos);
r = rewriter.create<vector::ExtractOp>(loc, acc, d);
Value extrMask;
if (mask)
extrMask = rewriter.create<vector::ExtractOp>(loc, mask, pos);
extrMask = rewriter.create<vector::ExtractOp>(loc, mask, d);
std::optional<Value> m = createContractArithOp(
loc, a, op.getRhs(), r, kind, rewriter, isInt, extrMask);
if (!m.has_value())
return failure();
result = rewriter.create<vector::InsertOp>(loc, resType, *m, result, pos);
result = rewriter.create<vector::InsertOp>(loc, resType, *m, result, d);
}
rewriter.replaceOp(rootOp, result);

View File

@@ -77,9 +77,7 @@ public:
Value val = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
bnd, idx);
Value sel = rewriter.create<arith::SelectOp>(loc, val, trueVal, falseVal);
auto pos = rewriter.getI64ArrayAttr(d);
result =
rewriter.create<vector::InsertOp>(loc, dstType, sel, result, pos);
result = rewriter.create<vector::InsertOp>(loc, dstType, sel, result, d);
}
rewriter.replaceOp(op, result);
return success();
@@ -151,11 +149,9 @@ public:
loc, lowType, rewriter.getI64ArrayAttr(newDimSizes));
Value result = rewriter.create<arith::ConstantOp>(
loc, dstType, rewriter.getZeroAttr(dstType));
for (int64_t d = 0; d < trueDim; d++) {
auto pos = rewriter.getI64ArrayAttr(d);
for (int64_t d = 0; d < trueDim; d++)
result =
rewriter.create<vector::InsertOp>(loc, dstType, trueVal, result, pos);
}
rewriter.create<vector::InsertOp>(loc, dstType, trueVal, result, d);
rewriter.replaceOp(op, result);
return success();
}

View File

@@ -944,7 +944,7 @@ struct WarpOpExtract : public OpRewritePattern<WarpExecuteOnLane0Op> {
// Rewrite vector.extract with 1d source to vector.extractelement.
if (extractSrcType.getRank() == 1) {
assert(extractOp.getPosition().size() == 1 && "expected 1 index");
int64_t pos = cast<IntegerAttr>(extractOp.getPosition()[0]).getInt();
int64_t pos = extractOp.getPosition()[0];
rewriter.setInsertionPoint(extractOp);
rewriter.replaceOpWithNewOp<vector::ExtractElementOp>(
extractOp, extractOp.getVector(),
@@ -1201,7 +1201,7 @@ struct WarpOpInsert : public OpRewritePattern<WarpExecuteOnLane0Op> {
// Rewrite vector.insert with 1d dest to vector.insertelement.
if (insertOp.getDestVectorType().getRank() == 1) {
assert(insertOp.getPosition().size() == 1 && "expected 1 index");
int64_t pos = cast<IntegerAttr>(insertOp.getPosition()[0]).getInt();
int64_t pos = insertOp.getPosition()[0];
rewriter.setInsertionPoint(insertOp);
rewriter.replaceOpWithNewOp<vector::InsertElementOp>(
insertOp, insertOp.getSource(), insertOp.getDest(),
@@ -1276,10 +1276,7 @@ struct WarpOpInsert : public OpRewritePattern<WarpExecuteOnLane0Op> {
} else {
// One lane inserts the entire source vector.
int64_t elementsPerLane = distrDestType.getDimSize(distrDestDim);
SmallVector<int64_t> newPos = llvm::to_vector(
llvm::map_range(insertOp.getPosition(), [](Attribute attr) {
return cast<IntegerAttr>(attr).getInt();
}));
SmallVector<int64_t> newPos(insertOp.getPosition());
// tid of inserting lane: pos / elementsPerLane
Value insertingLane = rewriter.create<arith::ConstantIndexOp>(
loc, newPos[distrDestDim] / elementsPerLane);

View File

@@ -165,16 +165,14 @@ struct CastAwayInsertLeadingOneDim : public OpRewritePattern<vector::InsertOp> {
// type has leading unit dims, we also trim the position array accordingly,
// then (2) if source type also has leading unit dims, we need to append
// zeroes to the position array accordingly.
unsigned oldPosRank = insertOp.getPosition().getValue().size();
unsigned oldPosRank = insertOp.getPosition().size();
unsigned newPosRank = std::max<int64_t>(0, oldPosRank - dstDropCount);
SmallVector<Attribute> newPositions = llvm::to_vector(
insertOp.getPosition().getValue().take_back(newPosRank));
newPositions.resize(newDstType.getRank() - newSrcRank,
rewriter.getI64IntegerAttr(0));
SmallVector<int64_t> newPositions =
llvm::to_vector(insertOp.getPosition().take_back(newPosRank));
newPositions.resize(newDstType.getRank() - newSrcRank, 0);
auto newInsertOp = rewriter.create<vector::InsertOp>(
loc, newDstType, newSrcVector, newDstVector,
rewriter.getArrayAttr(newPositions));
loc, newDstType, newSrcVector, newDstVector, newPositions);
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(insertOp, oldDstType,
newInsertOp);

View File

@@ -704,7 +704,7 @@ class RewriteScalarExtractOfTransferRead
SmallVector<Value> newIndices(xferOp.getIndices().begin(),
xferOp.getIndices().end());
for (const auto &it : llvm::enumerate(extractOp.getPosition())) {
int64_t offset = cast<IntegerAttr>(it.value()).getInt();
int64_t offset = it.value();
int64_t idx =
newIndices.size() - extractOp.getPosition().size() + it.index();
OpFoldResult ofr = affine::makeComposedFoldedAffineApply(

View File

@@ -598,11 +598,7 @@ struct BubbleDownVectorBitCastForExtract
unsigned expandRatio =
castDstType.getNumElements() / castSrcType.getNumElements();
auto getFirstIntValue = [](ArrayAttr attr) -> uint64_t {
return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue();
};
uint64_t index = getFirstIntValue(extractOp.getPosition());
uint64_t index = extractOp.getPosition()[0];
// Get the single scalar (as a vector) in the source value that packs the
// desired scalar. E.g. extract vector<1xf32> from vector<4xf32>
@@ -610,7 +606,7 @@ struct BubbleDownVectorBitCastForExtract
VectorType::get({1}, castSrcType.getElementType());
Value packedValue = rewriter.create<vector::ExtractOp>(
extractOp.getLoc(), oneScalarType, castOp.getSource(),
rewriter.getI64ArrayAttr(index / expandRatio));
index / expandRatio);
// Cast it to a vector with the desired scalar's type.
// E.g. f32 -> vector<2xf16>
@@ -621,8 +617,7 @@ struct BubbleDownVectorBitCastForExtract
// Finally extract the desired scalar.
rewriter.replaceOpWithNewOp<vector::ExtractOp>(
extractOp, extractOp.getType(), castedValue,
rewriter.getI64ArrayAttr(index % expandRatio));
extractOp, extractOp.getType(), castedValue, index % expandRatio);
return success();
}

View File

@@ -155,8 +155,8 @@ func.func @broadcast(%arg0 : f32) -> (vector<4xf32>, vector<2xf32>) {
// CHECK: spirv.CompositeExtract %[[ARG]][0 : i32] : vector<2xf32>
// CHECK: spirv.CompositeExtract %[[ARG]][1 : i32] : vector<2xf32>
func.func @extract(%arg0 : vector<2xf32>) -> (vector<1xf32>, f32) {
%0 = "vector.extract"(%arg0) {position = [0]} : (vector<2xf32>) -> vector<1xf32>
%1 = "vector.extract"(%arg0) {position = [1]} : (vector<2xf32>) -> f32
%0 = "vector.extract"(%arg0) <{position = array<i64: 0>}> : (vector<2xf32>) -> vector<1xf32>
%1 = "vector.extract"(%arg0) <{position = array<i64: 1>}> : (vector<2xf32>) -> f32
return %0, %1: vector<1xf32>, f32
}

View File

@@ -133,7 +133,7 @@ func.func @extract_position_rank_overflow(%arg0: vector<4x8x16xf32>) {
func.func @extract_position_rank_overflow_generic(%arg0: vector<4x8x16xf32>) {
// expected-error@+1 {{expected position attribute of rank no greater than vector rank}}
%1 = "vector.extract" (%arg0) { position = [0, 0, 0, 0] } : (vector<4x8x16xf32>) -> (vector<16xf32>)
%1 = "vector.extract" (%arg0) <{position = array<i64: 0, 0, 0, 0>}> : (vector<4x8x16xf32>) -> (vector<16xf32>)
}
// -----