mirror of
https://github.com/intel/llvm.git
synced 2026-02-02 02:00:03 +08:00
[mlir][vector] Use DenseI64ArrayAttr for ExtractOp/InsertOp positions
`DenseI64ArrayAttr` provides a better API than `I64ArrayAttr`; for example, the generated accessors return `ArrayRef<int64_t>` instead of `ArrayAttr`.
Differential Revision: https://reviews.llvm.org/D156684
This commit is contained in:
@@ -573,7 +573,7 @@ def Vector_ExtractOp :
|
||||
PredOpTrait<"operand and result have same element type",
|
||||
TCresVTEtIsSameAsOpBase<0, 0>>,
|
||||
InferTypeOpAdaptorWithIsCompatible]>,
|
||||
Arguments<(ins AnyVectorOfAnyRank:$vector, I64ArrayAttr:$position)>,
|
||||
Arguments<(ins AnyVectorOfAnyRank:$vector, DenseI64ArrayAttr:$position)>,
|
||||
Results<(outs AnyType)> {
|
||||
let summary = "extract operation";
|
||||
let description = [{
|
||||
@@ -589,7 +589,6 @@ def Vector_ExtractOp :
|
||||
```
|
||||
}];
|
||||
let builders = [
|
||||
OpBuilder<(ins "Value":$source, "ArrayRef<int64_t>":$position)>,
|
||||
// Convenience builder which assumes the values in `position` are defined by
|
||||
// ConstantIndexOp.
|
||||
OpBuilder<(ins "Value":$source, "ValueRange":$position)>
|
||||
@@ -689,7 +688,7 @@ def Vector_InsertOp :
|
||||
PredOpTrait<"source operand and result have same element type",
|
||||
TCresVTEtIsSameAsOpBase<0, 0>>,
|
||||
AllTypesMatch<["dest", "res"]>]>,
|
||||
Arguments<(ins AnyType:$source, AnyVectorOfAnyRank:$dest, I64ArrayAttr:$position)>,
|
||||
Arguments<(ins AnyType:$source, AnyVectorOfAnyRank:$dest, DenseI64ArrayAttr:$position)>,
|
||||
Results<(outs AnyVectorOfAnyRank:$res)> {
|
||||
let summary = "insert operation";
|
||||
let description = [{
|
||||
@@ -711,8 +710,6 @@ def Vector_InsertOp :
|
||||
}];
|
||||
|
||||
let builders = [
|
||||
OpBuilder<(ins "Value":$source, "Value":$dest,
|
||||
"ArrayRef<int64_t>":$position)>,
|
||||
// Convenience builder which assumes all values are constant indices.
|
||||
OpBuilder<(ins "Value":$source, "Value":$dest, "ValueRange":$position)>
|
||||
];
|
||||
|
||||
@@ -807,8 +807,7 @@ createNonLdMatrixLoads(RewriterBase &rewriter, vector::TransferReadOp op,
|
||||
|
||||
Value el = rewriter.create<vector::LoadOp>(loc, loadedElType,
|
||||
op.getSource(), newIndices);
|
||||
result = rewriter.create<vector::InsertOp>(loc, el, result,
|
||||
rewriter.getI64ArrayAttr(i));
|
||||
result = rewriter.create<vector::InsertOp>(loc, el, result, i);
|
||||
}
|
||||
} else {
|
||||
if (auto vecType = dyn_cast<VectorType>(loadedElType)) {
|
||||
@@ -832,7 +831,7 @@ createNonLdMatrixLoads(RewriterBase &rewriter, vector::TransferReadOp op,
|
||||
Value el = rewriter.create<memref::LoadOp>(op.getLoc(), loadedElType,
|
||||
op.getSource(), newIndices);
|
||||
result = rewriter.create<vector::InsertOp>(
|
||||
op.getLoc(), el, result, rewriter.getI64ArrayAttr({i, innerIdx}));
|
||||
op.getLoc(), el, result, ArrayRef<int64_t>{i, innerIdx});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1025,44 +1025,37 @@ public:
|
||||
auto loc = extractOp->getLoc();
|
||||
auto resultType = extractOp.getResult().getType();
|
||||
auto llvmResultType = typeConverter->convertType(resultType);
|
||||
auto positionArrayAttr = extractOp.getPosition();
|
||||
ArrayRef<int64_t> positionArray = extractOp.getPosition();
|
||||
|
||||
// Bail if result type cannot be lowered.
|
||||
if (!llvmResultType)
|
||||
return failure();
|
||||
|
||||
// Extract entire vector. Should be handled by folder, but just to be safe.
|
||||
if (positionArrayAttr.empty()) {
|
||||
if (positionArray.empty()) {
|
||||
rewriter.replaceOp(extractOp, adaptor.getVector());
|
||||
return success();
|
||||
}
|
||||
|
||||
// One-shot extraction of vector from array (only requires extractvalue).
|
||||
if (isa<VectorType>(resultType)) {
|
||||
SmallVector<int64_t> indices;
|
||||
for (auto idx : positionArrayAttr.getAsRange<IntegerAttr>())
|
||||
indices.push_back(idx.getInt());
|
||||
Value extracted = rewriter.create<LLVM::ExtractValueOp>(
|
||||
loc, adaptor.getVector(), indices);
|
||||
loc, adaptor.getVector(), positionArray);
|
||||
rewriter.replaceOp(extractOp, extracted);
|
||||
return success();
|
||||
}
|
||||
|
||||
// Potential extraction of 1-D vector from array.
|
||||
Value extracted = adaptor.getVector();
|
||||
auto positionAttrs = positionArrayAttr.getValue();
|
||||
if (positionAttrs.size() > 1) {
|
||||
SmallVector<int64_t> nMinusOnePosition;
|
||||
for (auto idx : positionAttrs.drop_back())
|
||||
nMinusOnePosition.push_back(cast<IntegerAttr>(idx).getInt());
|
||||
extracted = rewriter.create<LLVM::ExtractValueOp>(loc, extracted,
|
||||
nMinusOnePosition);
|
||||
if (positionArray.size() > 1) {
|
||||
extracted = rewriter.create<LLVM::ExtractValueOp>(
|
||||
loc, extracted, positionArray.drop_back());
|
||||
}
|
||||
|
||||
// Remaining extraction of element from 1-D LLVM vector
|
||||
auto position = cast<IntegerAttr>(positionAttrs.back());
|
||||
auto i64Type = IntegerType::get(rewriter.getContext(), 64);
|
||||
auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
|
||||
auto constant =
|
||||
rewriter.create<LLVM::ConstantOp>(loc, i64Type, positionArray.back());
|
||||
extracted =
|
||||
rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
|
||||
rewriter.replaceOp(extractOp, extracted);
|
||||
@@ -1147,7 +1140,7 @@ public:
|
||||
auto sourceType = insertOp.getSourceType();
|
||||
auto destVectorType = insertOp.getDestVectorType();
|
||||
auto llvmResultType = typeConverter->convertType(destVectorType);
|
||||
auto positionArrayAttr = insertOp.getPosition();
|
||||
ArrayRef<int64_t> positionArray = insertOp.getPosition();
|
||||
|
||||
// Bail if result type cannot be lowered.
|
||||
if (!llvmResultType)
|
||||
@@ -1155,7 +1148,7 @@ public:
|
||||
|
||||
// Overwrite entire vector with value. Should be handled by folder, but
|
||||
// just to be safe.
|
||||
if (positionArrayAttr.empty()) {
|
||||
if (positionArray.empty()) {
|
||||
rewriter.replaceOp(insertOp, adaptor.getSource());
|
||||
return success();
|
||||
}
|
||||
@@ -1163,36 +1156,32 @@ public:
|
||||
// One-shot insertion of a vector into an array (only requires insertvalue).
|
||||
if (isa<VectorType>(sourceType)) {
|
||||
Value inserted = rewriter.create<LLVM::InsertValueOp>(
|
||||
loc, adaptor.getDest(), adaptor.getSource(),
|
||||
LLVM::convertArrayToIndices(positionArrayAttr));
|
||||
loc, adaptor.getDest(), adaptor.getSource(), positionArray);
|
||||
rewriter.replaceOp(insertOp, inserted);
|
||||
return success();
|
||||
}
|
||||
|
||||
// Potential extraction of 1-D vector from array.
|
||||
Value extracted = adaptor.getDest();
|
||||
auto positionAttrs = positionArrayAttr.getValue();
|
||||
auto position = cast<IntegerAttr>(positionAttrs.back());
|
||||
auto oneDVectorType = destVectorType;
|
||||
if (positionAttrs.size() > 1) {
|
||||
if (positionArray.size() > 1) {
|
||||
oneDVectorType = reducedVectorTypeBack(destVectorType);
|
||||
extracted = rewriter.create<LLVM::ExtractValueOp>(
|
||||
loc, extracted,
|
||||
LLVM::convertArrayToIndices(positionAttrs.drop_back()));
|
||||
loc, extracted, positionArray.drop_back());
|
||||
}
|
||||
|
||||
// Insertion of an element into a 1-D LLVM vector.
|
||||
auto i64Type = IntegerType::get(rewriter.getContext(), 64);
|
||||
auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
|
||||
auto constant =
|
||||
rewriter.create<LLVM::ConstantOp>(loc, i64Type, positionArray.back());
|
||||
Value inserted = rewriter.create<LLVM::InsertElementOp>(
|
||||
loc, typeConverter->convertType(oneDVectorType), extracted,
|
||||
adaptor.getSource(), constant);
|
||||
|
||||
// Potential insertion of resulting 1-D vector into array.
|
||||
if (positionAttrs.size() > 1) {
|
||||
if (positionArray.size() > 1) {
|
||||
inserted = rewriter.create<LLVM::InsertValueOp>(
|
||||
loc, adaptor.getDest(), inserted,
|
||||
LLVM::convertArrayToIndices(positionAttrs.drop_back()));
|
||||
loc, adaptor.getDest(), inserted, positionArray.drop_back());
|
||||
}
|
||||
|
||||
rewriter.replaceOp(insertOp, inserted);
|
||||
|
||||
@@ -886,10 +886,9 @@ struct UnrollTransferReadConversion
|
||||
/// vector::InsertOp, return that operation's indices.
|
||||
void getInsertionIndices(TransferReadOp xferOp,
|
||||
SmallVector<int64_t, 8> &indices) const {
|
||||
if (auto insertOp = getInsertOp(xferOp)) {
|
||||
for (Attribute attr : insertOp.getPosition())
|
||||
indices.push_back(dyn_cast<IntegerAttr>(attr).getInt());
|
||||
}
|
||||
if (auto insertOp = getInsertOp(xferOp))
|
||||
indices.assign(insertOp.getPosition().begin(),
|
||||
insertOp.getPosition().end());
|
||||
}
|
||||
|
||||
/// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds
|
||||
@@ -1013,10 +1012,9 @@ struct UnrollTransferWriteConversion
|
||||
/// indices.
|
||||
void getExtractionIndices(TransferWriteOp xferOp,
|
||||
SmallVector<int64_t, 8> &indices) const {
|
||||
if (auto extractOp = getExtractOp(xferOp)) {
|
||||
for (Attribute attr : extractOp.getPosition())
|
||||
indices.push_back(dyn_cast<IntegerAttr>(attr).getInt());
|
||||
}
|
||||
if (auto extractOp = getExtractOp(xferOp))
|
||||
indices.assign(extractOp.getPosition().begin(),
|
||||
extractOp.getPosition().end());
|
||||
}
|
||||
|
||||
/// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds
|
||||
|
||||
@@ -152,7 +152,7 @@ struct VectorExtractOpConvert final
|
||||
return success();
|
||||
}
|
||||
|
||||
int32_t id = getFirstIntValue(extractOp.getPosition());
|
||||
int32_t id = extractOp.getPosition()[0];
|
||||
rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
|
||||
extractOp, adaptor.getVector(), id);
|
||||
return success();
|
||||
@@ -232,7 +232,7 @@ struct VectorInsertOpConvert final
|
||||
return success();
|
||||
}
|
||||
|
||||
int32_t id = getFirstIntValue(insertOp.getPosition());
|
||||
int32_t id = insertOp.getPosition()[0];
|
||||
rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
|
||||
insertOp, adaptor.getSource(), adaptor.getDest(), id);
|
||||
return success();
|
||||
|
||||
@@ -385,8 +385,7 @@ struct ElideUnitDimsInMultiDimReduction
|
||||
} else {
|
||||
// This means we are reducing all the dimensions, and all reduction
|
||||
// dimensions are of size 1. So a simple extraction would do.
|
||||
auto zeroAttr =
|
||||
rewriter.getI64ArrayAttr(SmallVector<int64_t>(shape.size(), 0));
|
||||
SmallVector<int64_t> zeroAttr(shape.size(), 0);
|
||||
if (mask)
|
||||
mask = rewriter.create<vector::ExtractOp>(loc, rewriter.getI1Type(),
|
||||
mask, zeroAttr);
|
||||
@@ -560,12 +559,10 @@ struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> {
|
||||
result = rewriter.create<ExtractElementOp>(loc, reductionOp.getVector());
|
||||
} else {
|
||||
if (mask) {
|
||||
mask = rewriter.create<ExtractOp>(loc, rewriter.getI1Type(), mask,
|
||||
rewriter.getI64ArrayAttr(0));
|
||||
mask = rewriter.create<ExtractOp>(loc, rewriter.getI1Type(), mask, 0);
|
||||
}
|
||||
result = rewriter.create<ExtractOp>(loc, reductionOp.getType(),
|
||||
reductionOp.getVector(),
|
||||
rewriter.getI64ArrayAttr(0));
|
||||
reductionOp.getVector(), 0);
|
||||
}
|
||||
|
||||
if (Value acc = reductionOp.getAcc())
|
||||
@@ -1129,18 +1126,11 @@ OpFoldResult vector::ExtractElementOp::fold(FoldAdaptor adaptor) {
|
||||
// ExtractOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
|
||||
Value source, ArrayRef<int64_t> position) {
|
||||
build(builder, result, source, getVectorSubscriptAttr(builder, position));
|
||||
}
|
||||
|
||||
// Convenience builder which assumes the values are constant indices.
|
||||
void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
|
||||
Value source, ValueRange position) {
|
||||
SmallVector<int64_t, 4> positionConstants =
|
||||
llvm::to_vector<4>(llvm::map_range(position, [](Value pos) {
|
||||
return getConstantIntValue(pos).value();
|
||||
}));
|
||||
SmallVector<int64_t> positionConstants = llvm::to_vector(llvm::map_range(
|
||||
position, [](Value pos) { return getConstantIntValue(pos).value(); }));
|
||||
build(builder, result, source, positionConstants);
|
||||
}
|
||||
|
||||
@@ -1175,15 +1165,13 @@ bool ExtractOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
|
||||
}
|
||||
|
||||
LogicalResult vector::ExtractOp::verify() {
|
||||
auto positionAttr = getPosition().getValue();
|
||||
if (positionAttr.size() >
|
||||
static_cast<unsigned>(getSourceVectorType().getRank()))
|
||||
ArrayRef<int64_t> position = getPosition();
|
||||
if (position.size() > static_cast<unsigned>(getSourceVectorType().getRank()))
|
||||
return emitOpError(
|
||||
"expected position attribute of rank no greater than vector rank");
|
||||
for (const auto &en : llvm::enumerate(positionAttr)) {
|
||||
auto attr = llvm::dyn_cast<IntegerAttr>(en.value());
|
||||
if (!attr || attr.getInt() < 0 ||
|
||||
attr.getInt() >= getSourceVectorType().getDimSize(en.index()))
|
||||
for (const auto &en : llvm::enumerate(position)) {
|
||||
if (en.value() < 0 ||
|
||||
en.value() >= getSourceVectorType().getDimSize(en.index()))
|
||||
return emitOpError("expected position attribute #")
|
||||
<< (en.index() + 1)
|
||||
<< " to be a non-negative integer smaller than the corresponding "
|
||||
@@ -1207,18 +1195,18 @@ static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) {
|
||||
|
||||
SmallVector<int64_t, 4> globalPosition;
|
||||
ExtractOp currentOp = extractOp;
|
||||
auto extrPos = extractVector<int64_t>(currentOp.getPosition());
|
||||
ArrayRef<int64_t> extrPos = currentOp.getPosition();
|
||||
globalPosition.append(extrPos.rbegin(), extrPos.rend());
|
||||
while (ExtractOp nextOp = currentOp.getVector().getDefiningOp<ExtractOp>()) {
|
||||
currentOp = nextOp;
|
||||
auto extrPos = extractVector<int64_t>(currentOp.getPosition());
|
||||
ArrayRef<int64_t> extrPos = currentOp.getPosition();
|
||||
globalPosition.append(extrPos.rbegin(), extrPos.rend());
|
||||
}
|
||||
extractOp.setOperand(currentOp.getVector());
|
||||
// OpBuilder is only used as a helper to build an I64ArrayAttr.
|
||||
OpBuilder b(extractOp.getContext());
|
||||
std::reverse(globalPosition.begin(), globalPosition.end());
|
||||
extractOp.setPositionAttr(b.getI64ArrayAttr(globalPosition));
|
||||
extractOp.setPosition(globalPosition);
|
||||
return success();
|
||||
}
|
||||
|
||||
@@ -1329,7 +1317,8 @@ ExtractFromInsertTransposeChainState::ExtractFromInsertTransposeChainState(
|
||||
sentinels.reserve(vectorRank - extractedRank);
|
||||
for (int64_t i = 0, e = vectorRank - extractedRank; i < e; ++i)
|
||||
sentinels.push_back(-(i + 1));
|
||||
extractPosition = extractVector<int64_t>(extractOp.getPosition());
|
||||
extractPosition.assign(extractOp.getPosition().begin(),
|
||||
extractOp.getPosition().end());
|
||||
llvm::append_range(extractPosition, sentinels);
|
||||
}
|
||||
|
||||
@@ -1349,9 +1338,8 @@ LogicalResult ExtractFromInsertTransposeChainState::handleTransposeOp() {
|
||||
LogicalResult
|
||||
ExtractFromInsertTransposeChainState::handleInsertOpWithMatchingPos(
|
||||
Value &res) {
|
||||
auto insertedPos = extractVector<int64_t>(nextInsertOp.getPosition());
|
||||
if (ArrayRef(insertedPos) !=
|
||||
llvm::ArrayRef(extractPosition).take_front(extractedRank))
|
||||
ArrayRef<int64_t> insertedPos = nextInsertOp.getPosition();
|
||||
if (insertedPos != llvm::ArrayRef(extractPosition).take_front(extractedRank))
|
||||
return failure();
|
||||
// Case 2.a. early-exit fold.
|
||||
res = nextInsertOp.getSource();
|
||||
@@ -1364,7 +1352,7 @@ ExtractFromInsertTransposeChainState::handleInsertOpWithMatchingPos(
|
||||
/// This method updates the internal state.
|
||||
LogicalResult
|
||||
ExtractFromInsertTransposeChainState::handleInsertOpWithPrefixPos(Value &res) {
|
||||
auto insertedPos = extractVector<int64_t>(nextInsertOp.getPosition());
|
||||
ArrayRef<int64_t> insertedPos = nextInsertOp.getPosition();
|
||||
if (!isContainedWithin(insertedPos, extractPosition))
|
||||
return failure();
|
||||
// Set leading dims to zero.
|
||||
@@ -1390,9 +1378,7 @@ Value ExtractFromInsertTransposeChainState::tryToFoldExtractOpInPlace(
|
||||
return Value();
|
||||
// Otherwise, fold by updating the op inplace and return its result.
|
||||
OpBuilder b(extractOp.getContext());
|
||||
extractOp->setAttr(
|
||||
extractOp.getPositionAttrName(),
|
||||
b.getI64ArrayAttr(ArrayRef(extractPosition).take_front(extractedRank)));
|
||||
extractOp.setPosition(ArrayRef(extractPosition).take_front(extractedRank));
|
||||
extractOp.getVectorMutable().assign(source);
|
||||
return extractOp.getResult();
|
||||
}
|
||||
@@ -1422,7 +1408,7 @@ Value ExtractFromInsertTransposeChainState::fold() {
|
||||
|
||||
// Case 4: extractPositionRef intersects insertedPosRef on non-sentinel
|
||||
// values. This is a more difficult case and we bail.
|
||||
auto insertedPos = extractVector<int64_t>(nextInsertOp.getPosition());
|
||||
ArrayRef<int64_t> insertedPos = nextInsertOp.getPosition();
|
||||
if (isContainedWithin(extractPosition, insertedPos) ||
|
||||
intersectsWhereNonNegative(extractPosition, insertedPos))
|
||||
return Value();
|
||||
@@ -1487,7 +1473,7 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) {
|
||||
// extract position to `0` when extracting from the source operand.
|
||||
llvm::SetVector<int64_t> broadcastedUnitDims =
|
||||
broadcastOp.computeBroadcastedUnitDims();
|
||||
auto extractPos = extractVector<int64_t>(extractOp.getPosition());
|
||||
SmallVector<int64_t> extractPos(extractOp.getPosition());
|
||||
for (int64_t i = rankDiff, e = extractPos.size(); i < e; ++i)
|
||||
if (broadcastedUnitDims.contains(i))
|
||||
extractPos[i] = 0;
|
||||
@@ -1498,7 +1484,7 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) {
|
||||
// OpBuilder is only used as a helper to build an I64ArrayAttr.
|
||||
OpBuilder b(extractOp.getContext());
|
||||
extractOp.setOperand(source);
|
||||
extractOp.setPositionAttr(b.getI64ArrayAttr(extractPos));
|
||||
extractOp.setPosition(extractPos);
|
||||
return extractOp.getResult();
|
||||
}
|
||||
|
||||
@@ -1537,7 +1523,7 @@ static Value foldExtractFromShapeCast(ExtractOp extractOp) {
|
||||
}
|
||||
// Extract the strides associated with the extract op vector source. Then use
|
||||
// this to calculate a linearized position for the extract.
|
||||
auto extractedPos = extractVector<int64_t>(extractOp.getPosition());
|
||||
SmallVector<int64_t> extractedPos(extractOp.getPosition());
|
||||
std::reverse(extractedPos.begin(), extractedPos.end());
|
||||
SmallVector<int64_t, 4> strides;
|
||||
int64_t stride = 1;
|
||||
@@ -1563,7 +1549,7 @@ static Value foldExtractFromShapeCast(ExtractOp extractOp) {
|
||||
SmallVector<int64_t, 4> newPosition = delinearize(position, newStrides);
|
||||
// OpBuilder is only used as a helper to build an I64ArrayAttr.
|
||||
OpBuilder b(extractOp.getContext());
|
||||
extractOp.setPositionAttr(b.getI64ArrayAttr(newPosition));
|
||||
extractOp.setPosition(newPosition);
|
||||
extractOp.setOperand(shapeCastOp.getSource());
|
||||
return extractOp.getResult();
|
||||
}
|
||||
@@ -1603,14 +1589,14 @@ static Value foldExtractFromExtractStrided(ExtractOp extractOp) {
|
||||
if (destinationRank > extractStridedSliceOp.getSourceVectorType().getRank() -
|
||||
sliceOffsets.size())
|
||||
return Value();
|
||||
auto extractedPos = extractVector<int64_t>(extractOp.getPosition());
|
||||
SmallVector<int64_t> extractedPos(extractOp.getPosition());
|
||||
assert(extractedPos.size() >= sliceOffsets.size());
|
||||
for (size_t i = 0, e = sliceOffsets.size(); i < e; i++)
|
||||
extractedPos[i] = extractedPos[i] + sliceOffsets[i];
|
||||
extractOp.getVectorMutable().assign(extractStridedSliceOp.getVector());
|
||||
// OpBuilder is only used as a helper to build an I64ArrayAttr.
|
||||
OpBuilder b(extractOp.getContext());
|
||||
extractOp.setPositionAttr(b.getI64ArrayAttr(extractedPos));
|
||||
extractOp.setPosition(extractedPos);
|
||||
return extractOp.getResult();
|
||||
}
|
||||
|
||||
@@ -1635,7 +1621,7 @@ static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp) {
|
||||
if (destinationRank > insertOp.getSourceVectorType().getRank())
|
||||
return Value();
|
||||
auto insertOffsets = extractVector<int64_t>(insertOp.getOffsets());
|
||||
auto extractOffsets = extractVector<int64_t>(extractOp.getPosition());
|
||||
ArrayRef<int64_t> extractOffsets = extractOp.getPosition();
|
||||
|
||||
if (llvm::any_of(insertOp.getStrides(), [](Attribute attr) {
|
||||
return llvm::cast<IntegerAttr>(attr).getInt() != 1;
|
||||
@@ -1675,7 +1661,7 @@ static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp) {
|
||||
extractOp.getVectorMutable().assign(insertOp.getSource());
|
||||
// OpBuilder is only used as a helper to build an I64ArrayAttr.
|
||||
OpBuilder b(extractOp.getContext());
|
||||
extractOp.setPositionAttr(b.getI64ArrayAttr(offsetDiffs));
|
||||
extractOp.setPosition(offsetDiffs);
|
||||
return extractOp.getResult();
|
||||
}
|
||||
// If the chunk extracted is disjoint from the chunk inserted, keep
|
||||
@@ -1795,7 +1781,7 @@ public:
|
||||
// Calculate the linearized position of the continuous chunk of elements to
|
||||
// extract.
|
||||
llvm::SmallVector<int64_t> completePositions(vecTy.getRank(), 0);
|
||||
copy(getI64SubArray(extractOp.getPosition()), completePositions.begin());
|
||||
copy(extractOp.getPosition(), completePositions.begin());
|
||||
int64_t elemBeginPosition =
|
||||
linearize(completePositions, computeStrides(vecTy.getShape()));
|
||||
auto denseValuesBegin = dense.value_begin<TypedAttr>() + elemBeginPosition;
|
||||
@@ -2288,14 +2274,6 @@ OpFoldResult vector::InsertElementOp::fold(FoldAdaptor adaptor) {
|
||||
// InsertOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void InsertOp::build(OpBuilder &builder, OperationState &result, Value source,
|
||||
Value dest, ArrayRef<int64_t> position) {
|
||||
result.addOperands({source, dest});
|
||||
auto positionAttr = getVectorSubscriptAttr(builder, position);
|
||||
result.addTypes(dest.getType());
|
||||
result.addAttribute(InsertOp::getPositionAttrName(result.name), positionAttr);
|
||||
}
|
||||
|
||||
// Convenience builder which assumes the values are constant indices.
|
||||
void InsertOp::build(OpBuilder &builder, OperationState &result, Value source,
|
||||
Value dest, ValueRange position) {
|
||||
@@ -2307,25 +2285,24 @@ void InsertOp::build(OpBuilder &builder, OperationState &result, Value source,
|
||||
}
|
||||
|
||||
LogicalResult InsertOp::verify() {
|
||||
auto positionAttr = getPosition().getValue();
|
||||
ArrayRef<int64_t> position = getPosition();
|
||||
auto destVectorType = getDestVectorType();
|
||||
if (positionAttr.size() > static_cast<unsigned>(destVectorType.getRank()))
|
||||
if (position.size() > static_cast<unsigned>(destVectorType.getRank()))
|
||||
return emitOpError(
|
||||
"expected position attribute of rank no greater than dest vector rank");
|
||||
auto srcVectorType = llvm::dyn_cast<VectorType>(getSourceType());
|
||||
if (srcVectorType &&
|
||||
(static_cast<unsigned>(srcVectorType.getRank()) + positionAttr.size() !=
|
||||
(static_cast<unsigned>(srcVectorType.getRank()) + position.size() !=
|
||||
static_cast<unsigned>(destVectorType.getRank())))
|
||||
return emitOpError("expected position attribute rank + source rank to "
|
||||
"match dest vector rank");
|
||||
if (!srcVectorType &&
|
||||
(positionAttr.size() != static_cast<unsigned>(destVectorType.getRank())))
|
||||
(position.size() != static_cast<unsigned>(destVectorType.getRank())))
|
||||
return emitOpError(
|
||||
"expected position attribute rank to match the dest vector rank");
|
||||
for (const auto &en : llvm::enumerate(positionAttr)) {
|
||||
auto attr = llvm::dyn_cast<IntegerAttr>(en.value());
|
||||
if (!attr || attr.getInt() < 0 ||
|
||||
attr.getInt() >= destVectorType.getDimSize(en.index()))
|
||||
for (const auto &en : llvm::enumerate(position)) {
|
||||
int64_t attr = en.value();
|
||||
if (attr < 0 || attr >= destVectorType.getDimSize(en.index()))
|
||||
return emitOpError("expected position attribute #")
|
||||
<< (en.index() + 1)
|
||||
<< " to be a non-negative integer smaller than the corresponding "
|
||||
@@ -2412,7 +2389,7 @@ public:
|
||||
// Calculate the linearized position of the continuous chunk of elements to
|
||||
// insert.
|
||||
llvm::SmallVector<int64_t> completePositions(destTy.getRank(), 0);
|
||||
copy(getI64SubArray(op.getPosition()), completePositions.begin());
|
||||
copy(op.getPosition(), completePositions.begin());
|
||||
int64_t insertBeginPosition =
|
||||
linearize(completePositions, computeStrides(destTy.getShape()));
|
||||
|
||||
|
||||
@@ -91,10 +91,8 @@ static Value reshapeLoad(Location loc, Value val, VectorType type,
|
||||
return val;
|
||||
Type lowType = VectorType::Builder(type).dropDim(0);
|
||||
// At extraction dimension?
|
||||
if (index == 0) {
|
||||
auto posAttr = rewriter.getI64ArrayAttr(pos);
|
||||
return rewriter.create<vector::ExtractOp>(loc, lowType, val, posAttr);
|
||||
}
|
||||
if (index == 0)
|
||||
return rewriter.create<vector::ExtractOp>(loc, lowType, val, pos);
|
||||
// Unroll leading dimensions.
|
||||
VectorType vType = cast<VectorType>(lowType);
|
||||
Type resType = VectorType::Builder(type).dropDim(index);
|
||||
@@ -102,11 +100,10 @@ static Value reshapeLoad(Location loc, Value val, VectorType type,
|
||||
Value result = rewriter.create<arith::ConstantOp>(
|
||||
loc, resVectorType, rewriter.getZeroAttr(resVectorType));
|
||||
for (int64_t d = 0, e = resVectorType.getDimSize(0); d < e; d++) {
|
||||
auto posAttr = rewriter.getI64ArrayAttr(d);
|
||||
Value ext = rewriter.create<vector::ExtractOp>(loc, vType, val, posAttr);
|
||||
Value ext = rewriter.create<vector::ExtractOp>(loc, vType, val, d);
|
||||
Value load = reshapeLoad(loc, ext, vType, index - 1, pos, rewriter);
|
||||
result = rewriter.create<vector::InsertOp>(loc, resVectorType, load, result,
|
||||
posAttr);
|
||||
result =
|
||||
rewriter.create<vector::InsertOp>(loc, resVectorType, load, result, d);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
@@ -120,20 +117,17 @@ static Value reshapeStore(Location loc, Value val, Value result,
|
||||
if (index == -1)
|
||||
return val;
|
||||
// At insertion dimension?
|
||||
if (index == 0) {
|
||||
auto posAttr = rewriter.getI64ArrayAttr(pos);
|
||||
return rewriter.create<vector::InsertOp>(loc, type, val, result, posAttr);
|
||||
}
|
||||
if (index == 0)
|
||||
return rewriter.create<vector::InsertOp>(loc, type, val, result, pos);
|
||||
// Unroll leading dimensions.
|
||||
Type lowType = VectorType::Builder(type).dropDim(0);
|
||||
VectorType vType = cast<VectorType>(lowType);
|
||||
Type insType = VectorType::Builder(vType).dropDim(0);
|
||||
for (int64_t d = 0, e = type.getDimSize(0); d < e; d++) {
|
||||
auto posAttr = rewriter.getI64ArrayAttr(d);
|
||||
Value ext = rewriter.create<vector::ExtractOp>(loc, vType, result, posAttr);
|
||||
Value ins = rewriter.create<vector::ExtractOp>(loc, insType, val, posAttr);
|
||||
Value ext = rewriter.create<vector::ExtractOp>(loc, vType, result, d);
|
||||
Value ins = rewriter.create<vector::ExtractOp>(loc, insType, val, d);
|
||||
Value sto = reshapeStore(loc, ins, ext, vType, index - 1, pos, rewriter);
|
||||
result = rewriter.create<vector::InsertOp>(loc, type, sto, result, posAttr);
|
||||
result = rewriter.create<vector::InsertOp>(loc, type, sto, result, d);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
@@ -823,10 +817,8 @@ struct ContractOpToElementwise
|
||||
newRhs = rewriter.create<vector::TransposeOp>(loc, newRhs, rhsTranspose);
|
||||
SmallVector<int64_t> lhsOffsets(lhsReductionDims.size(), 0);
|
||||
SmallVector<int64_t> rhsOffsets(rhsReductionDims.size(), 0);
|
||||
newLhs = rewriter.create<vector::ExtractOp>(
|
||||
loc, newLhs, rewriter.getI64ArrayAttr(lhsOffsets));
|
||||
newRhs = rewriter.create<vector::ExtractOp>(
|
||||
loc, newRhs, rewriter.getI64ArrayAttr(rhsOffsets));
|
||||
newLhs = rewriter.create<vector::ExtractOp>(loc, newLhs, lhsOffsets);
|
||||
newRhs = rewriter.create<vector::ExtractOp>(loc, newRhs, rhsOffsets);
|
||||
std::optional<Value> result =
|
||||
createContractArithOp(loc, newLhs, newRhs, contractOp.getAcc(),
|
||||
contractOp.getKind(), rewriter, isInt);
|
||||
@@ -1167,21 +1159,20 @@ public:
|
||||
Value result = rewriter.create<arith::ConstantOp>(
|
||||
loc, resType, rewriter.getZeroAttr(resType));
|
||||
for (int64_t d = 0, e = resType.getDimSize(0); d < e; ++d) {
|
||||
auto pos = rewriter.getI64ArrayAttr(d);
|
||||
Value x = rewriter.create<vector::ExtractOp>(loc, op.getLhs(), pos);
|
||||
Value x = rewriter.create<vector::ExtractOp>(loc, op.getLhs(), d);
|
||||
Value a = rewriter.create<vector::BroadcastOp>(loc, rhsType, x);
|
||||
Value r = nullptr;
|
||||
if (acc)
|
||||
r = rewriter.create<vector::ExtractOp>(loc, acc, pos);
|
||||
r = rewriter.create<vector::ExtractOp>(loc, acc, d);
|
||||
Value extrMask;
|
||||
if (mask)
|
||||
extrMask = rewriter.create<vector::ExtractOp>(loc, mask, pos);
|
||||
extrMask = rewriter.create<vector::ExtractOp>(loc, mask, d);
|
||||
|
||||
std::optional<Value> m = createContractArithOp(
|
||||
loc, a, op.getRhs(), r, kind, rewriter, isInt, extrMask);
|
||||
if (!m.has_value())
|
||||
return failure();
|
||||
result = rewriter.create<vector::InsertOp>(loc, resType, *m, result, pos);
|
||||
result = rewriter.create<vector::InsertOp>(loc, resType, *m, result, d);
|
||||
}
|
||||
|
||||
rewriter.replaceOp(rootOp, result);
|
||||
|
||||
@@ -77,9 +77,7 @@ public:
|
||||
Value val = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
|
||||
bnd, idx);
|
||||
Value sel = rewriter.create<arith::SelectOp>(loc, val, trueVal, falseVal);
|
||||
auto pos = rewriter.getI64ArrayAttr(d);
|
||||
result =
|
||||
rewriter.create<vector::InsertOp>(loc, dstType, sel, result, pos);
|
||||
result = rewriter.create<vector::InsertOp>(loc, dstType, sel, result, d);
|
||||
}
|
||||
rewriter.replaceOp(op, result);
|
||||
return success();
|
||||
@@ -151,11 +149,9 @@ public:
|
||||
loc, lowType, rewriter.getI64ArrayAttr(newDimSizes));
|
||||
Value result = rewriter.create<arith::ConstantOp>(
|
||||
loc, dstType, rewriter.getZeroAttr(dstType));
|
||||
for (int64_t d = 0; d < trueDim; d++) {
|
||||
auto pos = rewriter.getI64ArrayAttr(d);
|
||||
for (int64_t d = 0; d < trueDim; d++)
|
||||
result =
|
||||
rewriter.create<vector::InsertOp>(loc, dstType, trueVal, result, pos);
|
||||
}
|
||||
rewriter.create<vector::InsertOp>(loc, dstType, trueVal, result, d);
|
||||
rewriter.replaceOp(op, result);
|
||||
return success();
|
||||
}
|
||||
|
||||
@@ -944,7 +944,7 @@ struct WarpOpExtract : public OpRewritePattern<WarpExecuteOnLane0Op> {
|
||||
// Rewrite vector.extract with 1d source to vector.extractelement.
|
||||
if (extractSrcType.getRank() == 1) {
|
||||
assert(extractOp.getPosition().size() == 1 && "expected 1 index");
|
||||
int64_t pos = cast<IntegerAttr>(extractOp.getPosition()[0]).getInt();
|
||||
int64_t pos = extractOp.getPosition()[0];
|
||||
rewriter.setInsertionPoint(extractOp);
|
||||
rewriter.replaceOpWithNewOp<vector::ExtractElementOp>(
|
||||
extractOp, extractOp.getVector(),
|
||||
@@ -1201,7 +1201,7 @@ struct WarpOpInsert : public OpRewritePattern<WarpExecuteOnLane0Op> {
|
||||
// Rewrite vector.insert with 1d dest to vector.insertelement.
|
||||
if (insertOp.getDestVectorType().getRank() == 1) {
|
||||
assert(insertOp.getPosition().size() == 1 && "expected 1 index");
|
||||
int64_t pos = cast<IntegerAttr>(insertOp.getPosition()[0]).getInt();
|
||||
int64_t pos = insertOp.getPosition()[0];
|
||||
rewriter.setInsertionPoint(insertOp);
|
||||
rewriter.replaceOpWithNewOp<vector::InsertElementOp>(
|
||||
insertOp, insertOp.getSource(), insertOp.getDest(),
|
||||
@@ -1276,10 +1276,7 @@ struct WarpOpInsert : public OpRewritePattern<WarpExecuteOnLane0Op> {
|
||||
} else {
|
||||
// One lane inserts the entire source vector.
|
||||
int64_t elementsPerLane = distrDestType.getDimSize(distrDestDim);
|
||||
SmallVector<int64_t> newPos = llvm::to_vector(
|
||||
llvm::map_range(insertOp.getPosition(), [](Attribute attr) {
|
||||
return cast<IntegerAttr>(attr).getInt();
|
||||
}));
|
||||
SmallVector<int64_t> newPos(insertOp.getPosition());
|
||||
// tid of inserting lane: pos / elementsPerLane
|
||||
Value insertingLane = rewriter.create<arith::ConstantIndexOp>(
|
||||
loc, newPos[distrDestDim] / elementsPerLane);
|
||||
|
||||
@@ -165,16 +165,14 @@ struct CastAwayInsertLeadingOneDim : public OpRewritePattern<vector::InsertOp> {
|
||||
// type has leading unit dims, we also trim the position array accordingly,
|
||||
// then (2) if source type also has leading unit dims, we need to append
|
||||
// zeroes to the position array accordingly.
|
||||
unsigned oldPosRank = insertOp.getPosition().getValue().size();
|
||||
unsigned oldPosRank = insertOp.getPosition().size();
|
||||
unsigned newPosRank = std::max<int64_t>(0, oldPosRank - dstDropCount);
|
||||
SmallVector<Attribute> newPositions = llvm::to_vector(
|
||||
insertOp.getPosition().getValue().take_back(newPosRank));
|
||||
newPositions.resize(newDstType.getRank() - newSrcRank,
|
||||
rewriter.getI64IntegerAttr(0));
|
||||
SmallVector<int64_t> newPositions =
|
||||
llvm::to_vector(insertOp.getPosition().take_back(newPosRank));
|
||||
newPositions.resize(newDstType.getRank() - newSrcRank, 0);
|
||||
|
||||
auto newInsertOp = rewriter.create<vector::InsertOp>(
|
||||
loc, newDstType, newSrcVector, newDstVector,
|
||||
rewriter.getArrayAttr(newPositions));
|
||||
loc, newDstType, newSrcVector, newDstVector, newPositions);
|
||||
|
||||
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(insertOp, oldDstType,
|
||||
newInsertOp);
|
||||
|
||||
@@ -704,7 +704,7 @@ class RewriteScalarExtractOfTransferRead
|
||||
SmallVector<Value> newIndices(xferOp.getIndices().begin(),
|
||||
xferOp.getIndices().end());
|
||||
for (const auto &it : llvm::enumerate(extractOp.getPosition())) {
|
||||
int64_t offset = cast<IntegerAttr>(it.value()).getInt();
|
||||
int64_t offset = it.value();
|
||||
int64_t idx =
|
||||
newIndices.size() - extractOp.getPosition().size() + it.index();
|
||||
OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
|
||||
|
||||
@@ -598,11 +598,7 @@ struct BubbleDownVectorBitCastForExtract
|
||||
unsigned expandRatio =
|
||||
castDstType.getNumElements() / castSrcType.getNumElements();
|
||||
|
||||
auto getFirstIntValue = [](ArrayAttr attr) -> uint64_t {
|
||||
return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue();
|
||||
};
|
||||
|
||||
uint64_t index = getFirstIntValue(extractOp.getPosition());
|
||||
uint64_t index = extractOp.getPosition()[0];
|
||||
|
||||
// Get the single scalar (as a vector) in the source value that packs the
|
||||
// desired scalar. E.g. extract vector<1xf32> from vector<4xf32>
|
||||
@@ -610,7 +606,7 @@ struct BubbleDownVectorBitCastForExtract
|
||||
VectorType::get({1}, castSrcType.getElementType());
|
||||
Value packedValue = rewriter.create<vector::ExtractOp>(
|
||||
extractOp.getLoc(), oneScalarType, castOp.getSource(),
|
||||
rewriter.getI64ArrayAttr(index / expandRatio));
|
||||
index / expandRatio);
|
||||
|
||||
// Cast it to a vector with the desired scalar's type.
|
||||
// E.g. f32 -> vector<2xf16>
|
||||
@@ -621,8 +617,7 @@ struct BubbleDownVectorBitCastForExtract
|
||||
|
||||
// Finally extract the desired scalar.
|
||||
rewriter.replaceOpWithNewOp<vector::ExtractOp>(
|
||||
extractOp, extractOp.getType(), castedValue,
|
||||
rewriter.getI64ArrayAttr(index % expandRatio));
|
||||
extractOp, extractOp.getType(), castedValue, index % expandRatio);
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
@@ -155,8 +155,8 @@ func.func @broadcast(%arg0 : f32) -> (vector<4xf32>, vector<2xf32>) {
|
||||
// CHECK: spirv.CompositeExtract %[[ARG]][0 : i32] : vector<2xf32>
|
||||
// CHECK: spirv.CompositeExtract %[[ARG]][1 : i32] : vector<2xf32>
|
||||
func.func @extract(%arg0 : vector<2xf32>) -> (vector<1xf32>, f32) {
|
||||
%0 = "vector.extract"(%arg0) {position = [0]} : (vector<2xf32>) -> vector<1xf32>
|
||||
%1 = "vector.extract"(%arg0) {position = [1]} : (vector<2xf32>) -> f32
|
||||
%0 = "vector.extract"(%arg0) <{position = array<i64: 0>}> : (vector<2xf32>) -> vector<1xf32>
|
||||
%1 = "vector.extract"(%arg0) <{position = array<i64: 1>}> : (vector<2xf32>) -> f32
|
||||
return %0, %1: vector<1xf32>, f32
|
||||
}
|
||||
|
||||
|
||||
@@ -133,7 +133,7 @@ func.func @extract_position_rank_overflow(%arg0: vector<4x8x16xf32>) {
|
||||
|
||||
func.func @extract_position_rank_overflow_generic(%arg0: vector<4x8x16xf32>) {
|
||||
// expected-error@+1 {{expected position attribute of rank no greater than vector rank}}
|
||||
%1 = "vector.extract" (%arg0) { position = [0, 0, 0, 0] } : (vector<4x8x16xf32>) -> (vector<16xf32>)
|
||||
%1 = "vector.extract" (%arg0) <{position = array<i64: 0, 0, 0, 0>}> : (vector<4x8x16xf32>) -> (vector<16xf32>)
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
Reference in New Issue
Block a user