llvm/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
Andrzej Warzyński 2ee5586ac7 [mlir][vector] Make the in_bounds attribute mandatory (#97049)
At the moment, the `in_bounds` attribute has two confusing/contradictory
properties:
  1. It is both optional _and_ has an effective default-value.
  2. The default value is "out-of-bounds" for non-broadcast dims, and
     "in-bounds" for broadcast dims.

(see the `isDimInBounds` vector interface method for an example of this
"default" behaviour [1]).

This PR aims to clarify the logic surrounding the `in_bounds` attribute
by:
  * making the attribute mandatory (i.e. it is always present),
  * always setting the default value to "out of bounds" (that's
    consistent with the current behaviour for the most common cases).

#### Broadcast dimensions in tests

As per [2], broadcast dimensions require the corresponding
`in_bounds` attribute to be `true`:
```
  vector.transfer_read op requires broadcast dimensions to be in-bounds
```

The changes in this PR mean that we can no longer rely on the
default value in cases like the following (dim 0 is a broadcast dim):
```mlir
  %read = vector.transfer_read %A[%base1, %base2], %f, %mask
      {permutation_map = affine_map<(d0, d1) -> (0, d1)>} :
    memref<?x?xf32>, vector<4x9xf32>
```

Instead, the broadcast dimension has to be explicitly marked as "in
bounds":

```mlir
  %read = vector.transfer_read %A[%base1, %base2], %f, %mask
      {in_bounds = [true, false], permutation_map = affine_map<(d0, d1) -> (0, d1)>} :
    memref<?x?xf32>, vector<4x9xf32>
```

All tests with broadcast dims are updated accordingly.

#### Changes in "SuperVectorize.cpp" and "Vectorization.cpp"

The following patterns in "Vectorization.cpp" are updated to explicitly
set the `in_bounds` attribute to `false`:
* `LinalgCopyVTRForwardingPattern` and `LinalgCopyVTWForwardingPattern`

Also, `vectorizeAffineLoad` (from "SuperVectorize.cpp") and
`vectorizeAsLinalgGeneric` (from "Vectorization.cpp") are updated to
make sure that xfer Ops created by these hooks mark the dimensions
corresponding to broadcast dims as "in bounds". Otherwise, the Op
verifier would complain.

Note that there is no mechanism to verify whether the corresponding
memory accesses are indeed in bounds. Still, this is consistent with the
current behaviour where the broadcast dim would be implicitly assumed
to be "in bounds".

[1]
4145ad2bac/mlir/include/mlir/Interfaces/VectorInterfaces.td (L243-L246)
[2]
https://mlir.llvm.org/docs/Dialects/Vector/#vectortransfer_read-vectortransferreadop
2024-07-16 16:49:52 +01:00


//===- VectorTransferPermutationMapRewritePatterns.cpp - Xfer map rewrite -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements rewrite patterns for the permutation_map attribute of
// vector.transfer operations.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Interfaces/VectorInterfaces.h"
using namespace mlir;
using namespace mlir::vector;
/// Transpose a vector transfer op's `in_bounds` attribute by applying the
/// reverse permutation based on the given indices.
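/// For instance (a worked example, not from the original comment): with
/// permutation = [2, 0, 1] and attr = [true, false, false], the result is
/// [false, false, true], i.e. result[permutation[i]] = attr[i].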
static ArrayAttr
inverseTransposeInBoundsAttr(OpBuilder &builder, ArrayAttr attr,
                             const SmallVector<unsigned> &permutation) {
  SmallVector<bool> newInBoundsValues(permutation.size());
  size_t index = 0;
  for (unsigned pos : permutation)
    newInBoundsValues[pos] =
        cast<BoolAttr>(attr.getValue()[index++]).getValue();
  return builder.getBoolArrayAttr(newInBoundsValues);
}

/// Extend the rank of a vector Value by `addedRank` by adding outer unit
/// dimensions.
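/// E.g. (illustrative, not from the original comment): with addedRank = 2,
/// vector<4x8xf32> becomes vector<1x1x4x8xf32>.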
static Value extendVectorRank(OpBuilder &builder, Location loc, Value vec,
                              int64_t addedRank) {
  auto originalVecType = cast<VectorType>(vec.getType());
  SmallVector<int64_t> newShape(addedRank, 1);
  newShape.append(originalVecType.getShape().begin(),
                  originalVecType.getShape().end());

  SmallVector<bool> newScalableDims(addedRank, false);
  newScalableDims.append(originalVecType.getScalableDims().begin(),
                         originalVecType.getScalableDims().end());
  VectorType newVecType = VectorType::get(
      newShape, originalVecType.getElementType(), newScalableDims);
  return builder.create<vector::BroadcastOp>(loc, newVecType, vec);
}

/// Extend the rank of a vector Value by `addedRank` by adding inner unit
/// dimensions.
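/// E.g. (illustrative, not from the original comment): with addedRank = 2,
/// vector<4x8xi1> is first broadcast to vector<1x1x4x8xi1>, then transposed
/// with permutation [2, 3, 0, 1] to yield vector<4x8x1x1xi1>.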
static Value extendMaskRank(OpBuilder &builder, Location loc, Value vec,
                            int64_t addedRank) {
  Value broadcasted = extendVectorRank(builder, loc, vec, addedRank);
  SmallVector<int64_t> permutation;
  for (int64_t i = addedRank,
               e = cast<VectorType>(broadcasted.getType()).getRank();
       i < e; ++i)
    permutation.push_back(i);
  for (int64_t i = 0; i < addedRank; ++i)
    permutation.push_back(i);
  return builder.create<vector::TransposeOp>(loc, broadcasted, permutation);
}

//===----------------------------------------------------------------------===//
// populateVectorTransferPermutationMapLoweringPatterns
//===----------------------------------------------------------------------===//
namespace {
/// Lower transfer_read op with permutation into a transfer_read with a
/// permutation map composed of leading zeros followed by a minor identity +
/// vector.transpose op.
/// Ex:
///     vector.transfer_read ...
///         permutation_map: (d0, d1, d2) -> (0, d1)
///     into:
///     %v = vector.transfer_read ...
///         permutation_map: (d0, d1, d2) -> (d1, 0)
///     vector.transpose %v, [1, 0]
///
///     vector.transfer_read ...
///         permutation_map: (d0, d1, d2, d3) -> (0, 0, 0, d1, d3)
///     into:
///     %v = vector.transfer_read ...
///         permutation_map: (d0, d1, d2, d3) -> (0, 0, d1, 0, d3)
///     vector.transpose %v, [0, 1, 3, 2, 4]
/// Note that an alternative is to transform it to linalg.transpose +
/// vector.transfer_read to do the transpose in memory instead.
struct TransferReadPermutationLowering
    : public MaskableOpRewritePattern<vector::TransferReadOp> {
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  FailureOr<mlir::Value>
  matchAndRewriteMaskableOp(vector::TransferReadOp op,
                            MaskingOpInterface maskOp,
                            PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (op.getTransferRank() == 0)
      return rewriter.notifyMatchFailure(op, "0-d corner case not supported");
    // TODO: Support transfer_read inside MaskOp case.
    if (maskOp)
      return rewriter.notifyMatchFailure(op, "Masked case not supported");

    SmallVector<unsigned> permutation;
    AffineMap map = op.getPermutationMap();
    if (map.getNumResults() == 0)
      return rewriter.notifyMatchFailure(op, "0 result permutation map");
    if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation)) {
      return rewriter.notifyMatchFailure(
          op, "map is not permutable to minor identity, apply another pattern");
    }
    AffineMap permutationMap =
        map.getPermutationMap(permutation, op.getContext());
    if (permutationMap.isIdentity())
      return rewriter.notifyMatchFailure(op, "map is not identity");

    // Calculate the map of the new read by applying the inverse permutation.
    permutationMap = inversePermutation(permutationMap);
    AffineMap newMap = permutationMap.compose(map);

    // Apply the reverse transpose to deduce the type of the transfer_read.
    ArrayRef<int64_t> originalShape = op.getVectorType().getShape();
    SmallVector<int64_t> newVectorShape(originalShape.size());
    ArrayRef<bool> originalScalableDims = op.getVectorType().getScalableDims();
    SmallVector<bool> newScalableDims(originalShape.size());
    for (const auto &pos : llvm::enumerate(permutation)) {
      newVectorShape[pos.value()] = originalShape[pos.index()];
      newScalableDims[pos.value()] = originalScalableDims[pos.index()];
    }

    // Transpose in_bounds attribute.
    ArrayAttr newInBoundsAttr =
        inverseTransposeInBoundsAttr(rewriter, op.getInBounds(), permutation);

    // Generate new transfer_read operation.
    VectorType newReadType = VectorType::get(
        newVectorShape, op.getVectorType().getElementType(), newScalableDims);
    Value newRead = rewriter.create<vector::TransferReadOp>(
        op.getLoc(), newReadType, op.getSource(), op.getIndices(),
        AffineMapAttr::get(newMap), op.getPadding(), op.getMask(),
        newInBoundsAttr);

    // Transpose result of transfer_read.
    SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end());
    return rewriter
        .create<vector::TransposeOp>(op.getLoc(), newRead, transposePerm)
        .getResult();
  }
};

/// Lower transfer_write op with permutation into a transfer_write with a
/// minor identity permutation map. (transfer_write ops cannot have broadcasts.)
/// Ex:
///     vector.transfer_write %v ...
///         permutation_map: (d0, d1, d2) -> (d2, d0, d1)
///     into:
///     %tmp = vector.transpose %v, [2, 0, 1]
///     vector.transfer_write %tmp ...
///         permutation_map: (d0, d1, d2) -> (d0, d1, d2)
///
///     vector.transfer_write %v ...
///         permutation_map: (d0, d1, d2, d3) -> (d3, d2)
///     into:
///     %tmp = vector.transpose %v, [1, 0]
///     %v = vector.transfer_write %tmp ...
///         permutation_map: (d0, d1, d2, d3) -> (d2, d3)
struct TransferWritePermutationLowering
    : public MaskableOpRewritePattern<vector::TransferWriteOp> {
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  FailureOr<mlir::Value>
  matchAndRewriteMaskableOp(vector::TransferWriteOp op,
                            MaskingOpInterface maskOp,
                            PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (op.getTransferRank() == 0)
      return rewriter.notifyMatchFailure(op, "0-d corner case not supported");
    // TODO: Support transfer_write inside MaskOp case.
    if (maskOp)
      return rewriter.notifyMatchFailure(op, "Masked case not supported");

    SmallVector<unsigned> permutation;
    AffineMap map = op.getPermutationMap();
    if (map.isMinorIdentity())
      return rewriter.notifyMatchFailure(op, "map is already minor identity");
    if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation)) {
      return rewriter.notifyMatchFailure(
          op, "map is not permutable to minor identity, apply another pattern");
    }

    // Remove unused dims from the permutation map, e.g.:
    //   (d0, d1, d2, d3, d4, d5) -> (d5, d3, d4)
    //   comp = (d0, d1, d2) -> (d2, d0, d1)
    auto comp = compressUnusedDims(map);
    AffineMap permutationMap = inversePermutation(comp);
    // Get positions of remaining result dims.
    SmallVector<int64_t> indices;
    llvm::transform(permutationMap.getResults(), std::back_inserter(indices),
                    [](AffineExpr expr) {
                      return dyn_cast<AffineDimExpr>(expr).getPosition();
                    });

    // Transpose in_bounds attribute.
    ArrayAttr newInBoundsAttr =
        inverseTransposeInBoundsAttr(rewriter, op.getInBounds(), permutation);

    // Generate new transfer_write operation.
    Value newVec = rewriter.create<vector::TransposeOp>(
        op.getLoc(), op.getVector(), indices);
    auto newMap = AffineMap::getMinorIdentityMap(
        map.getNumDims(), map.getNumResults(), rewriter.getContext());
    auto newWrite = rewriter.create<vector::TransferWriteOp>(
        op.getLoc(), newVec, op.getSource(), op.getIndices(),
        AffineMapAttr::get(newMap), op.getMask(), newInBoundsAttr);
    if (newWrite.hasPureTensorSemantics())
      return newWrite.getResult();
    // In the memref case there's no return value. Use an empty Value to
    // signal success.
    return Value();
  }
};

/// Convert a transfer_write op with a map which isn't a permutation of a
/// minor identity into a vector.broadcast + transfer_write with a permutation
/// of a minor identity map, by adding unit dims on the inner dimensions. Ex:
/// ```
///   vector.transfer_write %v
///     {permutation_map = affine_map<(d0, d1, d2, d3) -> (d1, d2)>} :
///     vector<8x16xf32>
/// ```
/// into:
/// ```
///   %v1 = vector.broadcast %v : vector<8x16xf32> to vector<1x8x16xf32>
///   vector.transfer_write %v1
///     {permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, d1, d2)>} :
///     vector<1x8x16xf32>
/// ```
struct TransferWriteNonPermutationLowering
    : public MaskableOpRewritePattern<vector::TransferWriteOp> {
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  FailureOr<mlir::Value>
  matchAndRewriteMaskableOp(vector::TransferWriteOp op,
                            MaskingOpInterface maskOp,
                            PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (op.getTransferRank() == 0)
      return rewriter.notifyMatchFailure(op, "0-d corner case not supported");
    // TODO: Support transfer_write inside MaskOp case.
    if (maskOp)
      return rewriter.notifyMatchFailure(op, "Masked case not supported");

    SmallVector<unsigned> permutation;
    AffineMap map = op.getPermutationMap();
    if (map.isPermutationOfMinorIdentityWithBroadcasting(permutation)) {
      return rewriter.notifyMatchFailure(
          op,
          "map is already permutable to minor identity, apply another pattern");
    }

    // Missing outer dimensions are allowed: find the outermost existing
    // dimension, then deduce the missing inner dimensions.
    SmallVector<bool> foundDim(map.getNumDims(), false);
    for (AffineExpr exp : map.getResults())
      foundDim[cast<AffineDimExpr>(exp).getPosition()] = true;
    SmallVector<AffineExpr> exprs;
    bool foundFirstDim = false;
    SmallVector<int64_t> missingInnerDim;
    for (size_t i = 0; i < foundDim.size(); i++) {
      if (foundDim[i]) {
        foundFirstDim = true;
        continue;
      }
      if (!foundFirstDim)
        continue;
      // Once an outer dimension that exists in the map has been found, keep
      // track of all the missing dimensions after it.
      missingInnerDim.push_back(i);
      exprs.push_back(rewriter.getAffineDimExpr(i));
    }

    // Vector: add unit dims at the beginning of the shape.
    Value newVec = extendVectorRank(rewriter, op.getLoc(), op.getVector(),
                                    missingInnerDim.size());
    // Mask: add unit dims at the end of the shape.
    Value newMask;
    if (op.getMask())
      newMask = extendMaskRank(rewriter, op.getLoc(), op.getMask(),
                               missingInnerDim.size());
    exprs.append(map.getResults().begin(), map.getResults().end());
    AffineMap newMap =
        AffineMap::get(map.getNumDims(), 0, exprs, op.getContext());
    // All the newly added dimensions are in-bounds.
    SmallVector<bool> newInBoundsValues(missingInnerDim.size(), true);
    for (int64_t i = 0, e = op.getVectorType().getRank(); i < e; ++i) {
      newInBoundsValues.push_back(op.isDimInBounds(i));
    }
    ArrayAttr newInBoundsAttr = rewriter.getBoolArrayAttr(newInBoundsValues);
    auto newWrite = rewriter.create<vector::TransferWriteOp>(
        op.getLoc(), newVec, op.getSource(), op.getIndices(),
        AffineMapAttr::get(newMap), newMask, newInBoundsAttr);
    if (newWrite.hasPureTensorSemantics())
      return newWrite.getResult();
    // In the memref case there's no return value. Use an empty Value to
    // signal success.
    return Value();
  }
};

/// Lower transfer_read op with broadcast in the leading dimensions into
/// transfer_read of lower rank + vector.broadcast.
/// Ex: vector.transfer_read ...
///         permutation_map: (d0, d1, d2, d3) -> (0, d1, 0, d3)
///     into:
///     %v = vector.transfer_read ...
///         permutation_map: (d0, d1, d2, d3) -> (d1, 0, d3)
///     vector.broadcast %v
struct TransferOpReduceRank
    : public MaskableOpRewritePattern<vector::TransferReadOp> {
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  FailureOr<mlir::Value>
  matchAndRewriteMaskableOp(vector::TransferReadOp op,
                            MaskingOpInterface maskOp,
                            PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (op.getTransferRank() == 0)
      return rewriter.notifyMatchFailure(op, "0-d corner case not supported");
    // TODO: support masked case.
    if (maskOp)
      return rewriter.notifyMatchFailure(op, "Masked case not supported");

    AffineMap map = op.getPermutationMap();
    unsigned numLeadingBroadcast = 0;
    for (auto expr : map.getResults()) {
      auto dimExpr = dyn_cast<AffineConstantExpr>(expr);
      if (!dimExpr || dimExpr.getValue() != 0)
        break;
      numLeadingBroadcast++;
    }
    // If there are no leading zeros in the map there is nothing to do.
    if (numLeadingBroadcast == 0)
      return rewriter.notifyMatchFailure(op, "no leading broadcasts in map");

    VectorType originalVecType = op.getVectorType();
    unsigned reducedShapeRank = originalVecType.getRank() - numLeadingBroadcast;
    // Calculate new map, vector type and masks without the leading zeros.
    AffineMap newMap = AffineMap::get(
        map.getNumDims(), 0, map.getResults().take_back(reducedShapeRank),
        op.getContext());
    // Only remove the leading zeros if the rest of the map is a minor identity
    // with broadcasting. Otherwise we first want to permute the map.
    if (!newMap.isMinorIdentityWithBroadcasting()) {
      return rewriter.notifyMatchFailure(
          op, "map is not a minor identity with broadcasting");
    }

    // TODO: support zero-dimension vectors natively. See:
    // https://llvm.discourse.group/t/should-we-have-0-d-vectors/3097.
    // In the meantime, lower these to a scalar load when they pop up.
    if (reducedShapeRank == 0) {
      Value newRead;
      if (isa<TensorType>(op.getShapedType())) {
        newRead = rewriter.create<tensor::ExtractOp>(
            op.getLoc(), op.getSource(), op.getIndices());
      } else {
        newRead = rewriter.create<memref::LoadOp>(
            op.getLoc(), originalVecType.getElementType(), op.getSource(),
            op.getIndices());
      }
      return rewriter
          .create<vector::BroadcastOp>(op.getLoc(), originalVecType, newRead)
          .getVector();
    }

    SmallVector<int64_t> newShape(
        originalVecType.getShape().take_back(reducedShapeRank));
    SmallVector<bool> newScalableDims(
        originalVecType.getScalableDims().take_back(reducedShapeRank));
    // Vector rank cannot be zero. Handled by TransferReadToVectorLoadLowering.
    if (newShape.empty())
      return rewriter.notifyMatchFailure(op, "rank-reduced vector is 0-d");

    VectorType newReadType = VectorType::get(
        newShape, originalVecType.getElementType(), newScalableDims);
    ArrayAttr newInBoundsAttr =
        op.getInBounds()
            ? rewriter.getArrayAttr(
                  op.getInBoundsAttr().getValue().take_back(reducedShapeRank))
            : ArrayAttr();
    Value newRead = rewriter.create<vector::TransferReadOp>(
        op.getLoc(), newReadType, op.getSource(), op.getIndices(),
        AffineMapAttr::get(newMap), op.getPadding(), op.getMask(),
        newInBoundsAttr);
    return rewriter
        .create<vector::BroadcastOp>(op.getLoc(), originalVecType, newRead)
        .getVector();
  }
};

} // namespace

void mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns
      .add<TransferReadPermutationLowering, TransferWritePermutationLowering,
           TransferOpReduceRank, TransferWriteNonPermutationLowering>(
          patterns.getContext(), benefit);
}
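
// Typical usage of the patterns above (a sketch, not part of this file;
// assumes the greedy driver from mlir/Transforms/GreedyPatternRewriteDriver.h
// and a hypothetical root op `rootOp`):
//   RewritePatternSet patterns(ctx);
//   vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
//   (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));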
//===----------------------------------------------------------------------===//
// populateVectorTransferLoweringPatterns
//===----------------------------------------------------------------------===//
namespace {
/// Progressive lowering of transfer_read. This pattern supports lowering of
/// `vector.transfer_read` to a combination of `vector.load` and
/// `vector.broadcast` if all of the following hold:
/// - Stride of most minor memref dimension must be 1.
/// - Out-of-bounds masking is not required.
/// - If the memref's element type is a vector type then it coincides with the
///   result type.
/// - The permutation map doesn't perform permutation (broadcasting is allowed).
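///
/// E.g. (an illustrative sketch, not from the original comment, assuming an
/// in-bounds, unit-stride read):
///   %v = vector.transfer_read %m[%i], %pad {in_bounds = [true]}
///       : memref<8xf32>, vector<4xf32>
/// is lowered to:
///   %v = vector.load %m[%i] : memref<8xf32>, vector<4xf32>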
struct TransferReadToVectorLoadLowering
    : public MaskableOpRewritePattern<vector::TransferReadOp> {
  TransferReadToVectorLoadLowering(MLIRContext *context,
                                   std::optional<unsigned> maxRank,
                                   PatternBenefit benefit = 1)
      : MaskableOpRewritePattern<vector::TransferReadOp>(context, benefit),
        maxTransferRank(maxRank) {}

  FailureOr<mlir::Value>
  matchAndRewriteMaskableOp(vector::TransferReadOp read,
                            MaskingOpInterface maskOp,
                            PatternRewriter &rewriter) const override {
    if (maxTransferRank && read.getVectorType().getRank() > *maxTransferRank) {
      return rewriter.notifyMatchFailure(
          read, "vector type is greater than max transfer rank");
    }
    if (maskOp)
      return rewriter.notifyMatchFailure(read, "Masked case not supported");

    SmallVector<unsigned> broadcastedDims;
    // Permutations are handled by VectorToSCF or
    // populateVectorTransferPermutationMapLoweringPatterns.
    // We let the 0-d corner case pass-through as it is supported.
    if (!read.getPermutationMap().isMinorIdentityWithBroadcasting(
            &broadcastedDims))
      return rewriter.notifyMatchFailure(read, "not minor identity + bcast");

    auto memRefType = dyn_cast<MemRefType>(read.getShapedType());
    if (!memRefType)
      return rewriter.notifyMatchFailure(read, "not a memref source");

    // Non-unit strides are handled by VectorToSCF.
    if (!isLastMemrefDimUnitStride(memRefType))
      return rewriter.notifyMatchFailure(read, "!= 1 stride needs VectorToSCF");

    // If there is broadcasting involved then we first load the unbroadcasted
    // vector, and then broadcast it with `vector.broadcast`.
    ArrayRef<int64_t> vectorShape = read.getVectorType().getShape();
    SmallVector<int64_t> unbroadcastedVectorShape(vectorShape.begin(),
                                                  vectorShape.end());
    for (unsigned i : broadcastedDims)
      unbroadcastedVectorShape[i] = 1;
    VectorType unbroadcastedVectorType = read.getVectorType().cloneWith(
        unbroadcastedVectorShape, read.getVectorType().getElementType());

    // `vector.load` supports vector types as memref's elements only when the
    // resulting vector type is the same as the element type.
    auto memrefElTy = memRefType.getElementType();
    if (isa<VectorType>(memrefElTy) && memrefElTy != unbroadcastedVectorType)
      return rewriter.notifyMatchFailure(read, "incompatible element type");

    // Otherwise, element types of the memref and the vector must match.
    if (!isa<VectorType>(memrefElTy) &&
        memrefElTy != read.getVectorType().getElementType())
      return rewriter.notifyMatchFailure(read, "non-matching element type");

    // Out-of-bounds dims are handled by MaterializeTransferMask.
    if (read.hasOutOfBoundsDim())
      return rewriter.notifyMatchFailure(read, "out-of-bounds needs mask");

    // Create vector load op.
    Operation *res;
    if (read.getMask()) {
      if (read.getVectorType().getRank() != 1)
        // vector.maskedload operates on 1-D vectors.
        return rewriter.notifyMatchFailure(
            read, "vector type is not rank 1, can't create masked load, needs "
                  "VectorToSCF");

      Value fill = rewriter.create<vector::SplatOp>(
          read.getLoc(), unbroadcastedVectorType, read.getPadding());
      res = rewriter.create<vector::MaskedLoadOp>(
          read.getLoc(), unbroadcastedVectorType, read.getSource(),
          read.getIndices(), read.getMask(), fill);
    } else {
      res = rewriter.create<vector::LoadOp>(
          read.getLoc(), unbroadcastedVectorType, read.getSource(),
          read.getIndices());
    }

    // Insert a broadcasting op if required.
    if (!broadcastedDims.empty())
      res = rewriter.create<vector::BroadcastOp>(
          read.getLoc(), read.getVectorType(), res->getResult(0));
    return res->getResult(0);
  }

  std::optional<unsigned> maxTransferRank;
};

/// Replace a 0-d vector.load with a memref.load + vector.broadcast.
// TODO: we shouldn't cross the vector/scalar domains just for this
// but atm we lack the infra to avoid it. Possible solutions include:
// - go directly to LLVM + bitcast
// - introduce a bitcast op and likely a new pointer dialect
// - let memref.load/store additionally support the 0-d vector case
// There are still deeper data layout issues lingering even in this
// trivial case (for architectures for which this matters).
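//
// E.g. (an illustrative sketch for the single-element case, not from the
// original comment):
//   %v = vector.load %m[%i] : memref<8xf32>, vector<1xf32>
// is lowered to:
//   %s = memref.load %m[%i] : memref<8xf32>
//   %v = vector.broadcast %s : f32 to vector<1xf32>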
struct VectorLoadToMemrefLoadLowering
    : public OpRewritePattern<vector::LoadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::LoadOp loadOp,
                                PatternRewriter &rewriter) const override {
    auto vecType = loadOp.getVectorType();
    if (vecType.getNumElements() != 1)
      return rewriter.notifyMatchFailure(loadOp, "not a single element vector");
    auto memrefLoad = rewriter.create<memref::LoadOp>(
        loadOp.getLoc(), loadOp.getBase(), loadOp.getIndices());
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(loadOp, vecType,
                                                     memrefLoad);
    return success();
  }
};

/// Replace a 0-d vector.store with a vector.extractelement + memref.store.
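/// E.g. (an illustrative sketch for the single-element case, not from the
/// original comment):
///   vector.store %v, %m[%i] : memref<8xf32>, vector<1xf32>
/// is lowered to:
///   %s = vector.extract %v[0] : f32 from vector<1xf32>
///   memref.store %s, %m[%i] : memref<8xf32>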
struct VectorStoreToMemrefStoreLowering
    : public OpRewritePattern<vector::StoreOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::StoreOp storeOp,
                                PatternRewriter &rewriter) const override {
    auto vecType = storeOp.getVectorType();
    if (vecType.getNumElements() != 1)
      return rewriter.notifyMatchFailure(storeOp, "not single element vector");

    Value extracted;
    if (vecType.getRank() == 0) {
      // TODO: Unify once ExtractOp supports 0-d vectors.
      extracted = rewriter.create<vector::ExtractElementOp>(
          storeOp.getLoc(), storeOp.getValueToStore());
    } else {
      SmallVector<int64_t> indices(vecType.getRank(), 0);
      extracted = rewriter.create<vector::ExtractOp>(
          storeOp.getLoc(), storeOp.getValueToStore(), indices);
    }

    rewriter.replaceOpWithNewOp<memref::StoreOp>(
        storeOp, extracted, storeOp.getBase(), storeOp.getIndices());
    return success();
  }
};

/// Progressive lowering of transfer_write. This pattern supports lowering of
/// `vector.transfer_write` to `vector.store` if all of the following hold:
/// - Stride of most minor memref dimension must be 1.
/// - Out-of-bounds masking is not required.
/// - If the memref's element type is a vector type then it coincides with the
///   type of the written value.
/// - The permutation map is the minor identity map (neither permutation nor
///   broadcasting is allowed).
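///
/// E.g. (an illustrative sketch, not from the original comment, assuming an
/// in-bounds, unit-stride write):
///   vector.transfer_write %v, %m[%i] {in_bounds = [true]}
///       : vector<4xf32>, memref<8xf32>
/// is lowered to:
///   vector.store %v, %m[%i] : memref<8xf32>, vector<4xf32>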
struct TransferWriteToVectorStoreLowering
    : public MaskableOpRewritePattern<vector::TransferWriteOp> {
  TransferWriteToVectorStoreLowering(MLIRContext *context,
                                     std::optional<unsigned> maxRank,
                                     PatternBenefit benefit = 1)
      : MaskableOpRewritePattern<vector::TransferWriteOp>(context, benefit),
        maxTransferRank(maxRank) {}

  FailureOr<mlir::Value>
  matchAndRewriteMaskableOp(vector::TransferWriteOp write,
                            MaskingOpInterface maskOp,
                            PatternRewriter &rewriter) const override {
    if (maxTransferRank && write.getVectorType().getRank() > *maxTransferRank) {
      return rewriter.notifyMatchFailure(
          write, "vector type is greater than max transfer rank");
    }
    if (maskOp)
      return rewriter.notifyMatchFailure(write, "Masked case not supported");

    // Permutations are handled by VectorToSCF or
    // populateVectorTransferPermutationMapLoweringPatterns.
    if ( // pass-through for the 0-d corner case.
        !write.getPermutationMap().isMinorIdentity())
      return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
        diag << "permutation map is not minor identity: " << write;
      });

    auto memRefType = dyn_cast<MemRefType>(write.getShapedType());
    if (!memRefType)
      return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
        diag << "not a memref type: " << write;
      });

    // Non-unit strides are handled by VectorToSCF.
    if (!isLastMemrefDimUnitStride(memRefType))
      return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
        diag << "most minor stride is not 1: " << write;
      });

    // `vector.store` supports vector types as memref's elements only when the
    // type of the vector value being written is the same as the element type.
    auto memrefElTy = memRefType.getElementType();
    if (isa<VectorType>(memrefElTy) && memrefElTy != write.getVectorType())
      return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
        diag << "elemental type mismatch: " << write;
      });

    // Otherwise, element types of the memref and the vector must match.
    if (!isa<VectorType>(memrefElTy) &&
        memrefElTy != write.getVectorType().getElementType())
      return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
        diag << "elemental type mismatch: " << write;
      });

    // Out-of-bounds dims are handled by MaterializeTransferMask.
    if (write.hasOutOfBoundsDim())
      return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
        diag << "out of bounds dim: " << write;
      });
    if (write.getMask()) {
      if (write.getVectorType().getRank() != 1)
        // vector.maskedstore operates on 1-D vectors.
        return rewriter.notifyMatchFailure(
            write.getLoc(), [=](Diagnostic &diag) {
              diag << "vector type is not rank 1, can't create masked store, "
                      "needs VectorToSCF: "
                   << write;
            });

      rewriter.create<vector::MaskedStoreOp>(
          write.getLoc(), write.getSource(), write.getIndices(),
          write.getMask(), write.getVector());
    } else {
      rewriter.create<vector::StoreOp>(write.getLoc(), write.getVector(),
                                       write.getSource(), write.getIndices());
    }
    // There's no return value for StoreOps. Use Value() to signal success to
    // matchAndRewrite.
    return Value();
  }

  std::optional<unsigned> maxTransferRank;
};

} // namespace

void mlir::vector::populateVectorTransferLoweringPatterns(
    RewritePatternSet &patterns, std::optional<unsigned> maxTransferRank,
    PatternBenefit benefit) {
  patterns.add<TransferReadToVectorLoadLowering,
               TransferWriteToVectorStoreLowering>(patterns.getContext(),
                                                   maxTransferRank, benefit);
  patterns
      .add<VectorLoadToMemrefLoadLowering, VectorStoreToMemrefStoreLowering>(
          patterns.getContext(), benefit);
}