[mlir][vector] Remove vector.transfer_read/write to LLVM lowering
This simplifies the vector-to-LLVM lowering. Previously, both vector.load/store and vector.transfer_read/write lowered directly to LLVM. With this commit, there is a single path to LLVM vector load/store instructions, and vector.transfer_read/write ops must first be lowered to vector.load/store ops.

* Remove vector.transfer_read/write to LLVM lowering.
* Allow non-unit memref strides on all but the most minor dimension for vector.load/store ops.
* Add maxTransferRank option to populateVectorTransferLoweringPatterns.
* vector.transfer_reads with changing element type can no longer be lowered to LLVM. (This functionality is needed only for SPIRV.)

Differential Revision: https://reviews.llvm.org/D106118
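As a quick illustration of the new two-step path (not part of the commit; a hedged MLIR sketch with made-up value names %A, %i, %pad, %ptr and an assumed 1-D, in-bounds, unmasked read):

// Input:
%v = vector.transfer_read %A[%i], %pad {in_bounds = [true]} : memref<?xf32>, vector<16xf32>
// Step 1 (populateVectorTransferLoweringPatterns, maxTransferRank = 1): rewrite to vector.load.
%v = vector.load %A[%i] : memref<?xf32>, vector<16xf32>
// Step 2 (populateVectorToLLVMConversionPatterns): only vector.load is converted to LLVM;
// address computation via llvm.getelementptr/llvm.bitcast omitted.
%v = llvm.load %ptr {alignment = 4 : i64} : !llvm.ptr<vector<16xf32>>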
@@ -62,9 +62,12 @@ void populateBubbleVectorBitCastOpPatterns(RewritePatternSet &patterns);
 /// Collect a set of transfer read/write lowering patterns.
 ///
 /// These patterns lower transfer ops to simpler ops like `vector.load`,
-/// `vector.store` and `vector.broadcast`. Includes all patterns of
-/// populateVectorTransferPermutationMapLoweringPatterns.
-void populateVectorTransferLoweringPatterns(RewritePatternSet &patterns);
+/// `vector.store` and `vector.broadcast`. Only transfers with a transfer rank
+/// of at most `maxTransferRank` are lowered. This is useful when combined with
+/// VectorToSCF, which reduces the rank of vector transfer ops.
+void populateVectorTransferLoweringPatterns(
+RewritePatternSet &patterns,
+llvm::Optional<unsigned> maxTransferRank = llvm::None);
 
 /// Collect a set of transfer read/write lowering patterns that simplify the
 /// permutation map (e.g., converting it to a minor identity map) by inserting
@@ -185,6 +188,10 @@ ArrayAttr getVectorSubscriptAttr(Builder &b, ArrayRef<int64_t> values);
 Value getVectorReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc,
 Value vector);
 
+/// Return true if the last dimension of the MemRefType has unit stride. Also
+/// return true for memrefs with no strides.
+bool isLastMemrefDimUnitStride(MemRefType type);
+
 namespace impl {
 /// Build the default minor identity map suitable for a vector transfer. This
 /// also handles the case memref<... x vector<...>> -> vector<...> in which the

@@ -1409,9 +1409,9 @@ def Vector_LoadOp : Vector_Op<"load"> {
 based on the element type of the memref. The shape of the result vector
 type determines the shape of the slice read from the start memory address.
 The elements along each dimension of the slice are strided by the memref
-strides. Only memref with default strides are allowed. These constraints
-guarantee that elements read along the first dimension of the slice are
-contiguous in memory.
+strides. Only unit strides are allowed along the most minor memref
+dimension. These constraints guarantee that elements read along the first
+dimension of the slice are contiguous in memory.
 
 The memref element type can be a scalar or a vector type. If the memref
 element type is a scalar, it should match the element type of the result
@@ -1470,6 +1470,8 @@ def Vector_LoadOp : Vector_Op<"load"> {
 }
 }];
 
+let hasFolder = 1;
+
 let assemblyFormat =
 "$base `[` $indices `]` attr-dict `:` type($base) `,` type($result)";
 }
@@ -1484,9 +1486,9 @@ def Vector_StoreOp : Vector_Op<"store"> {
 memref dimension based on the element type of the memref. The shape of the
 vector value to store determines the shape of the slice written from the
 start memory address. The elements along each dimension of the slice are
-strided by the memref strides. Only memref with default strides are allowed.
-These constraints guarantee that elements written along the first dimension
-of the slice are contiguous in memory.
+strided by the memref strides. Only unit strides are allowed along the most
+minor memref dimension. These constraints guarantee that elements written
+along the first dimension of the slice are contiguous in memory.
 
 The memref element type can be a scalar or a vector type. If the memref
 element type is a scalar, it should match the element type of the value
@@ -1544,6 +1546,8 @@ def Vector_StoreOp : Vector_Op<"store"> {
 }
 }];
 
+let hasFolder = 1;
+
 let assemblyFormat = "$valueToStore `,` $base `[` $indices `]` attr-dict "
 "`:` type($base) `,` type($valueToStore)";
 }
@@ -1601,6 +1605,7 @@ def Vector_MaskedLoadOp :
 let assemblyFormat = "$base `[` $indices `]` `,` $mask `,` $pass_thru attr-dict `:` "
 "type($base) `,` type($mask) `,` type($pass_thru) `into` type($result)";
 let hasCanonicalizer = 1;
+let hasFolder = 1;
 }
 
 def Vector_MaskedStoreOp :
@@ -1653,6 +1658,7 @@ def Vector_MaskedStoreOp :
 "$base `[` $indices `]` `,` $mask `,` $valueToStore "
 "attr-dict `:` type($base) `,` type($mask) `,` type($valueToStore)";
 let hasCanonicalizer = 1;
+let hasFolder = 1;
 }
 
 def Vector_GatherOp :

@@ -130,18 +130,6 @@ static unsigned getAssumedAlignment(Value value) {
 }
 return align;
 }
-// Helper that returns data layout alignment of a memref associated with a
-// transfer op, including additional information from assume_alignment calls
-// on the source of the transfer
-LogicalResult getTransferOpAlignment(LLVMTypeConverter &typeConverter,
-VectorTransferOpInterface xfer,
-unsigned &align) {
-if (failed(getMemRefAlignment(
-typeConverter, xfer.getShapedType().cast<MemRefType>(), align)))
-return failure();
-align = std::max(align, getAssumedAlignment(xfer.source()));
-return success();
-}
 
 // Helper that returns data layout alignment of a memref associated with a
 // load, store, scatter, or gather op, including additional information from
@@ -181,79 +169,6 @@ static Value castDataPtr(ConversionPatternRewriter &rewriter, Location loc,
 return rewriter.create<LLVM::BitcastOp>(loc, pType, ptr);
 }
 
-static LogicalResult
-replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
-LLVMTypeConverter &typeConverter, Location loc,
-TransferReadOp xferOp,
-ArrayRef<Value> operands, Value dataPtr) {
-unsigned align;
-if (failed(getTransferOpAlignment(typeConverter, xferOp, align)))
-return failure();
-rewriter.replaceOpWithNewOp<LLVM::LoadOp>(xferOp, dataPtr, align);
-return success();
-}
-
-static LogicalResult
-replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
-LLVMTypeConverter &typeConverter, Location loc,
-TransferReadOp xferOp, ArrayRef<Value> operands,
-Value dataPtr, Value mask) {
-Type vecTy = typeConverter.convertType(xferOp.getVectorType());
-if (!vecTy)
-return failure();
-
-auto adaptor = TransferReadOpAdaptor(operands, xferOp->getAttrDictionary());
-Value fill = rewriter.create<SplatOp>(loc, vecTy, adaptor.padding());
-
-unsigned align;
-if (failed(getTransferOpAlignment(typeConverter, xferOp, align)))
-return failure();
-rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
-xferOp, vecTy, dataPtr, mask, ValueRange{fill},
-rewriter.getI32IntegerAttr(align));
-return success();
-}
-
-static LogicalResult
-replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
-LLVMTypeConverter &typeConverter, Location loc,
-TransferWriteOp xferOp,
-ArrayRef<Value> operands, Value dataPtr) {
-unsigned align;
-if (failed(getTransferOpAlignment(typeConverter, xferOp, align)))
-return failure();
-auto adaptor = TransferWriteOpAdaptor(operands, xferOp->getAttrDictionary());
-rewriter.replaceOpWithNewOp<LLVM::StoreOp>(xferOp, adaptor.vector(), dataPtr,
-align);
-return success();
-}
-
-static LogicalResult
-replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
-LLVMTypeConverter &typeConverter, Location loc,
-TransferWriteOp xferOp, ArrayRef<Value> operands,
-Value dataPtr, Value mask) {
-unsigned align;
-if (failed(getTransferOpAlignment(typeConverter, xferOp, align)))
-return failure();
-
-auto adaptor = TransferWriteOpAdaptor(operands, xferOp->getAttrDictionary());
-rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
-xferOp, adaptor.vector(), dataPtr, mask,
-rewriter.getI32IntegerAttr(align));
-return success();
-}
-
-static TransferReadOpAdaptor getTransferOpAdapter(TransferReadOp xferOp,
-ArrayRef<Value> operands) {
-return TransferReadOpAdaptor(operands, xferOp->getAttrDictionary());
-}
-
-static TransferWriteOpAdaptor getTransferOpAdapter(TransferWriteOp xferOp,
-ArrayRef<Value> operands) {
-return TransferWriteOpAdaptor(operands, xferOp->getAttrDictionary());
-}
-
 namespace {
 
 /// Conversion pattern for a vector.bitcast.
@@ -1026,15 +941,6 @@ public:
 }
 };
 
-/// Return true if the last dimension of the MemRefType has unit stride. Also
-/// return true for memrefs with no strides.
-static bool isLastMemrefDimUnitStride(MemRefType type) {
-int64_t offset;
-SmallVector<int64_t> strides;
-auto successStrides = getStridesAndOffset(type, strides, offset);
-return succeeded(successStrides) && (strides.empty() || strides.back() == 1);
-}
-
 /// Returns the strides if the memory underlying `memRefType` has a contiguous
 /// static layout.
 static llvm::Optional<SmallVector<int64_t, 4>>
@@ -1145,83 +1051,6 @@ public:
 }
 };
 
-/// Conversion pattern that converts a 1-D vector transfer read/write op into a
-/// a masked or unmasked read/write.
-template <typename ConcreteOp>
-class VectorTransferConversion : public ConvertOpToLLVMPattern<ConcreteOp> {
-public:
-using ConvertOpToLLVMPattern<ConcreteOp>::ConvertOpToLLVMPattern;
-
-LogicalResult
-matchAndRewrite(ConcreteOp xferOp, ArrayRef<Value> operands,
-ConversionPatternRewriter &rewriter) const override {
-auto adaptor = getTransferOpAdapter(xferOp, operands);
-
-if (xferOp.getVectorType().getRank() > 1 || xferOp.indices().empty())
-return failure();
-if (xferOp.permutation_map() !=
-AffineMap::getMinorIdentityMap(xferOp.permutation_map().getNumInputs(),
-xferOp.getVectorType().getRank(),
-xferOp->getContext()))
-return failure();
-auto memRefType = xferOp.getShapedType().template dyn_cast<MemRefType>();
-if (!memRefType)
-return failure();
-// Last dimension must be contiguous. (Otherwise: Use VectorToSCF.)
-if (!isLastMemrefDimUnitStride(memRefType))
-return failure();
-// Out-of-bounds dims are handled by MaterializeTransferMask.
-if (xferOp.hasOutOfBoundsDim())
-return failure();
-
-auto toLLVMTy = [&](Type t) {
-return this->getTypeConverter()->convertType(t);
-};
-
-Location loc = xferOp->getLoc();
-
-if (auto memrefVectorElementType =
-memRefType.getElementType().template dyn_cast<VectorType>()) {
-// Memref has vector element type.
-if (memrefVectorElementType.getElementType() !=
-xferOp.getVectorType().getElementType())
-return failure();
-#ifndef NDEBUG
-// Check that memref vector type is a suffix of 'vectorType.
-unsigned memrefVecEltRank = memrefVectorElementType.getRank();
-unsigned resultVecRank = xferOp.getVectorType().getRank();
-assert(memrefVecEltRank <= resultVecRank);
-// TODO: Move this to isSuffix in Vector/Utils.h.
-unsigned rankOffset = resultVecRank - memrefVecEltRank;
-auto memrefVecEltShape = memrefVectorElementType.getShape();
-auto resultVecShape = xferOp.getVectorType().getShape();
-for (unsigned i = 0; i < memrefVecEltRank; ++i)
-assert(memrefVecEltShape[i] != resultVecShape[rankOffset + i] &&
-"memref vector element shape should match suffix of vector "
-"result shape.");
-#endif // ifndef NDEBUG
-}
-
-// Get the source/dst address as an LLVM vector pointer.
-VectorType vtp = xferOp.getVectorType();
-Value dataPtr = this->getStridedElementPtr(
-loc, memRefType, adaptor.source(), adaptor.indices(), rewriter);
-Value vectorDataPtr =
-castDataPtr(rewriter, loc, dataPtr, memRefType, toLLVMTy(vtp));
-
-// Rewrite as an unmasked masked read / write.
-if (!xferOp.mask())
-return replaceTransferOpWithLoadOrStore(rewriter,
-*this->getTypeConverter(), loc,
-xferOp, operands, vectorDataPtr);
-
-// Rewrite as a masked read / write.
-return replaceTransferOpWithMasked(rewriter, *this->getTypeConverter(), loc,
-xferOp, operands, vectorDataPtr,
-xferOp.mask());
-}
-};
-
 class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
 public:
 using ConvertOpToLLVMPattern<vector::PrintOp>::ConvertOpToLLVMPattern;
@@ -1450,9 +1279,10 @@ void mlir::populateVectorToLLVMConversionPatterns(
 VectorLoadStoreConversion<vector::MaskedStoreOp,
 vector::MaskedStoreOpAdaptor>,
 VectorGatherOpConversion, VectorScatterOpConversion,
-VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
-VectorTransferConversion<TransferReadOp>,
-VectorTransferConversion<TransferWriteOp>>(converter);
+VectorExpandLoadOpConversion, VectorCompressStoreOpConversion>(
+converter);
+// Transfer ops with rank > 1 are handled by VectorToSCF.
+populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
 }
 
 void mlir::populateVectorToLLVMMatrixConversionPatterns(

@@ -64,6 +64,8 @@ void LowerVectorToLLVMPass::runOnOperation() {
 populateVectorToVectorCanonicalizationPatterns(patterns);
 populateVectorContractLoweringPatterns(patterns);
 populateVectorTransposeLoweringPatterns(patterns);
+// Vector transfer ops with rank > 1 should be lowered with VectorToSCF.
+populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
 }
 
@@ -71,6 +73,7 @@ void LowerVectorToLLVMPass::runOnOperation() {
 LLVMTypeConverter converter(&getContext());
 RewritePatternSet patterns(&getContext());
 populateVectorMaskMaterializationPatterns(patterns, enableIndexOptimizations);
+populateVectorTransferLoweringPatterns(patterns);
 populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
 populateVectorToLLVMConversionPatterns(converter, patterns,
 reassociateFPReductions);

@@ -89,7 +89,7 @@ void mlir::linalg::CodegenStrategy::transform(FuncOp func) const {
 .add<ContractionOpToOuterProductOpLowering,
 ContractionOpToMatmulOpLowering, ContractionOpLowering>(
 vectorTransformsOptions, context);
-vector::populateVectorTransferLoweringPatterns(
+vector::populateVectorTransferPermutationMapLoweringPatterns(
 vectorContractLoweringPatterns);
 (void)applyPatternsAndFoldGreedily(
 func, std::move(vectorContractLoweringPatterns));

@@ -102,6 +102,15 @@ static bool isSupportedCombiningKind(CombiningKind combiningKind,
 return false;
 }
 
+/// Return true if the last dimension of the MemRefType has unit stride. Also
+/// return true for memrefs with no strides.
+bool mlir::vector::isLastMemrefDimUnitStride(MemRefType type) {
+int64_t offset;
+SmallVector<int64_t> strides;
+auto successStrides = getStridesAndOffset(type, strides, offset);
+return succeeded(successStrides) && (strides.empty() || strides.back() == 1);
+}
+
 //===----------------------------------------------------------------------===//
 // CombiningKindAttr
 //===----------------------------------------------------------------------===//
@@ -2953,9 +2962,8 @@ void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results,
 
 static LogicalResult verifyLoadStoreMemRefLayout(Operation *op,
 MemRefType memRefTy) {
-auto affineMaps = memRefTy.getAffineMaps();
-if (!affineMaps.empty())
-return op->emitOpError("base memref should have a default identity layout");
+if (!isLastMemrefDimUnitStride(memRefTy))
+return op->emitOpError("most minor memref dim must have unit stride");
 return success();
 }
 
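For illustration only (not part of the diff): with the verifier change above, vector.load/store now accept any strided layout whose most minor stride is 1. A hedged MLIR sketch with made-up shapes, layouts, and value names:

// Accepted now: layout strides are [16, 1]; only the innermost stride must be 1.
%v = vector.load %m0[%i, %j] : memref<4x8xf32, affine_map<(d0, d1) -> (d0 * 16 + d1)>>, vector<8xf32>
// Still rejected: layout strides are [1, 4], i.e. the most minor stride is not 1.
// error: 'vector.store' op most minor memref dim must have unit stride
vector.store %w, %m1[%i, %j] : memref<4x8xf32, affine_map<(d0, d1) -> (d0 + d1 * 4)>>, vector<8xf32>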
@@ -2981,6 +2989,12 @@ static LogicalResult verify(vector::LoadOp op) {
 return success();
 }
 
+OpFoldResult LoadOp::fold(ArrayRef<Attribute>) {
+if (succeeded(foldMemRefCast(*this)))
+return getResult();
+return OpFoldResult();
+}
+
 //===----------------------------------------------------------------------===//
 // StoreOp
 //===----------------------------------------------------------------------===//
@@ -3008,6 +3022,11 @@ static LogicalResult verify(vector::StoreOp op) {
 return success();
 }
 
+LogicalResult StoreOp::fold(ArrayRef<Attribute> operands,
+SmallVectorImpl<OpFoldResult> &results) {
+return foldMemRefCast(*this);
+}
+
 //===----------------------------------------------------------------------===//
 // MaskedLoadOp
 //===----------------------------------------------------------------------===//
@@ -3056,6 +3075,12 @@ void MaskedLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
 results.add<MaskedLoadFolder>(context);
 }
 
+OpFoldResult MaskedLoadOp::fold(ArrayRef<Attribute>) {
+if (succeeded(foldMemRefCast(*this)))
+return getResult();
+return OpFoldResult();
+}
+
 //===----------------------------------------------------------------------===//
 // MaskedStoreOp
 //===----------------------------------------------------------------------===//
@@ -3101,6 +3126,11 @@ void MaskedStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
 results.add<MaskedStoreFolder>(context);
 }
 
+LogicalResult MaskedStoreOp::fold(ArrayRef<Attribute> operands,
+SmallVectorImpl<OpFoldResult> &results) {
+return foldMemRefCast(*this);
+}
+
 //===----------------------------------------------------------------------===//
 // GatherOp
 //===----------------------------------------------------------------------===//

@@ -2464,26 +2464,34 @@ struct TransferWriteInsertPattern
 /// Progressive lowering of transfer_read. This pattern supports lowering of
 /// `vector.transfer_read` to a combination of `vector.load` and
 /// `vector.broadcast` if all of the following hold:
-/// - The op reads from a memref with the default layout.
+/// - Stride of most minor memref dimension must be 1.
 /// - Out-of-bounds masking is not required.
 /// - If the memref's element type is a vector type then it coincides with the
 /// result type.
 /// - The permutation map doesn't perform permutation (broadcasting is allowed).
-/// - The op has no mask.
 struct TransferReadToVectorLoadLowering
 : public OpRewritePattern<vector::TransferReadOp> {
-using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
+TransferReadToVectorLoadLowering(MLIRContext *context,
+llvm::Optional<unsigned> maxRank)
+: OpRewritePattern<vector::TransferReadOp>(context),
+maxTransferRank(maxRank) {}
 
 LogicalResult matchAndRewrite(vector::TransferReadOp read,
 PatternRewriter &rewriter) const override {
+if (maxTransferRank && read.getVectorType().getRank() > *maxTransferRank)
+return failure();
 SmallVector<unsigned, 4> broadcastedDims;
-// TODO: Support permutations.
+// Permutations are handled by VectorToSCF or
+// populateVectorTransferPermutationMapLoweringPatterns.
 if (!read.permutation_map().isMinorIdentityWithBroadcasting(
 &broadcastedDims))
 return failure();
 auto memRefType = read.getShapedType().dyn_cast<MemRefType>();
 if (!memRefType)
 return failure();
+// Non-unit strides are handled by VectorToSCF.
+if (!vector::isLastMemrefDimUnitStride(memRefType))
+return failure();
 
 // If there is broadcasting involved then we first load the unbroadcasted
 // vector, and then broadcast it with `vector.broadcast`.
@@ -2497,32 +2505,44 @@ struct TransferReadToVectorLoadLowering
 
 // `vector.load` supports vector types as memref's elements only when the
 // resulting vector type is the same as the element type.
-if (memRefType.getElementType().isa<VectorType>() &&
-memRefType.getElementType() != unbroadcastedVectorType)
+auto memrefElTy = memRefType.getElementType();
+if (memrefElTy.isa<VectorType>() && memrefElTy != unbroadcastedVectorType)
 return failure();
-// Only the default layout is supported by `vector.load`.
-// TODO: Support non-default layouts.
-if (!memRefType.getAffineMaps().empty())
-return failure();
-// TODO: When out-of-bounds masking is required, we can create a
-// MaskedLoadOp.
-if (read.hasOutOfBoundsDim())
-return failure();
-if (read.mask())
+// Otherwise, element types of the memref and the vector must match.
+if (!memrefElTy.isa<VectorType>() &&
+memrefElTy != read.getVectorType().getElementType())
 return failure();
 
-auto loadOp = rewriter.create<vector::LoadOp>(
-read.getLoc(), unbroadcastedVectorType, read.source(), read.indices());
+// Out-of-bounds dims are handled by MaterializeTransferMask.
+if (read.hasOutOfBoundsDim())
+return failure();
+
+// Create vector load op.
+Operation *loadOp;
+if (read.mask()) {
+Value fill = rewriter.create<SplatOp>(
+read.getLoc(), unbroadcastedVectorType, read.padding());
+loadOp = rewriter.create<vector::MaskedLoadOp>(
+read.getLoc(), unbroadcastedVectorType, read.source(), read.indices(),
+read.mask(), fill);
+} else {
+loadOp = rewriter.create<vector::LoadOp>(read.getLoc(),
+unbroadcastedVectorType,
+read.source(), read.indices());
+}
+
 // Insert a broadcasting op if required.
 if (!broadcastedDims.empty()) {
 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
-read, read.getVectorType(), loadOp.result());
+read, read.getVectorType(), loadOp->getResult(0));
 } else {
-rewriter.replaceOp(read, loadOp.result());
+rewriter.replaceOp(read, loadOp->getResult(0));
 }
 
 return success();
 }
+
+llvm::Optional<unsigned> maxTransferRank;
 };
 
 /// Replace a scalar vector.load with a memref.load.
@@ -2545,44 +2565,56 @@ struct VectorLoadToMemrefLoadLowering
 
 /// Progressive lowering of transfer_write. This pattern supports lowering of
 /// `vector.transfer_write` to `vector.store` if all of the following hold:
-/// - The op writes to a memref with the default layout.
+/// - Stride of most minor memref dimension must be 1.
 /// - Out-of-bounds masking is not required.
 /// - If the memref's element type is a vector type then it coincides with the
 /// type of the written value.
 /// - The permutation map is the minor identity map (neither permutation nor
 /// broadcasting is allowed).
-/// - The op has no mask.
 struct TransferWriteToVectorStoreLowering
 : public OpRewritePattern<vector::TransferWriteOp> {
-using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;
+TransferWriteToVectorStoreLowering(MLIRContext *context,
+llvm::Optional<unsigned> maxRank)
+: OpRewritePattern<vector::TransferWriteOp>(context),
+maxTransferRank(maxRank) {}
 
 LogicalResult matchAndRewrite(vector::TransferWriteOp write,
 PatternRewriter &rewriter) const override {
-// TODO: Support non-minor-identity maps
+if (maxTransferRank && write.getVectorType().getRank() > *maxTransferRank)
+return failure();
+// Permutations are handled by VectorToSCF or
+// populateVectorTransferPermutationMapLoweringPatterns.
 if (!write.permutation_map().isMinorIdentity())
 return failure();
 auto memRefType = write.getShapedType().dyn_cast<MemRefType>();
 if (!memRefType)
 return failure();
+// Non-unit strides are handled by VectorToSCF.
+if (!vector::isLastMemrefDimUnitStride(memRefType))
+return failure();
 // `vector.store` supports vector types as memref's elements only when the
 // type of the vector value being written is the same as the element type.
-if (memRefType.getElementType().isa<VectorType>() &&
-memRefType.getElementType() != write.getVectorType())
+auto memrefElTy = memRefType.getElementType();
+if (memrefElTy.isa<VectorType>() && memrefElTy != write.getVectorType())
 return failure();
-// Only the default layout is supported by `vector.store`.
-// TODO: Support non-default layouts.
-if (!memRefType.getAffineMaps().empty())
+// Otherwise, element types of the memref and the vector must match.
+if (!memrefElTy.isa<VectorType>() &&
+memrefElTy != write.getVectorType().getElementType())
 return failure();
-// TODO: When out-of-bounds masking is required, we can create a
-// MaskedStoreOp.
+// Out-of-bounds dims are handled by MaterializeTransferMask.
 if (write.hasOutOfBoundsDim())
 return failure();
-if (write.mask())
-return failure();
-rewriter.replaceOpWithNewOp<vector::StoreOp>(
-write, write.vector(), write.source(), write.indices());
+if (write.mask()) {
+rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
+write, write.source(), write.indices(), write.mask(), write.vector());
+} else {
+rewriter.replaceOpWithNewOp<vector::StoreOp>(
+write, write.vector(), write.source(), write.indices());
+}
 return success();
 }
+
+llvm::Optional<unsigned> maxTransferRank;
 };
 
 /// Transpose a vector transfer op's `in_bounds` attribute according to given
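A hedged sketch (value names and shapes invented, not taken from the commit) of what the extended pattern above now produces for a masked 1-D transfer_write — the mask is forwarded to vector.maskedstore instead of blocking the rewrite:

// Before:
vector.transfer_write %val, %A[%i], %mask {in_bounds = [true]} : vector<4xf32>, memref<?xf32>
// After TransferWriteToVectorStoreLowering:
vector.maskedstore %A[%i], %mask, %val : memref<?xf32>, vector<4xi1>, vector<4xf32>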
@@ -2624,6 +2656,8 @@ struct TransferReadPermutationLowering
 PatternRewriter &rewriter) const override {
 SmallVector<unsigned> permutation;
 AffineMap map = op.permutation_map();
+if (map.getNumResults() == 0)
+return failure();
 if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation))
 return failure();
 AffineMap permutationMap =
@@ -3680,11 +3714,11 @@ void mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(
 }
 
 void mlir::vector::populateVectorTransferLoweringPatterns(
-RewritePatternSet &patterns) {
-patterns
-.add<TransferReadToVectorLoadLowering, TransferWriteToVectorStoreLowering,
-VectorLoadToMemrefLoadLowering>(patterns.getContext());
-populateVectorTransferPermutationMapLoweringPatterns(patterns);
+RewritePatternSet &patterns, llvm::Optional<unsigned> maxTransferRank) {
+patterns.add<TransferReadToVectorLoadLowering,
+TransferWriteToVectorStoreLowering>(patterns.getContext(),
+maxTransferRank);
+patterns.add<VectorLoadToMemrefLoadLowering>(patterns.getContext());
 }
 
 void mlir::vector::populateVectorMultiReductionLoweringPatterns(

@@ -1212,18 +1212,19 @@ func @transfer_read_1d(%A : memref<?xf32>, %base: index) -> vector<17xf32> {
 // CHECK: %[[dimVec:.*]] = splat %[[dtrunc]] : vector<17xi32>
 // CHECK: %[[mask:.*]] = cmpi slt, %[[offsetVec2]], %[[dimVec]] : vector<17xi32>
 //
-// 4. Bitcast to vector form.
+// 4. Create pass-through vector.
+// CHECK: %[[PASS_THROUGH:.*]] = splat %[[c7]] : vector<17xf32>
+//
+// 5. Bitcast to vector form.
 // CHECK: %[[gep:.*]] = llvm.getelementptr {{.*}} :
 // CHECK-SAME: (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
 // CHECK: %[[vecPtr:.*]] = llvm.bitcast %[[gep]] :
 // CHECK-SAME: !llvm.ptr<f32> to !llvm.ptr<vector<17xf32>>
-//
-// 5. Rewrite as a masked read.
-// CHECK: %[[PASS_THROUGH:.*]] = splat %[[c7]] : vector<17xf32>
+// 6. Rewrite as a masked read.
 // CHECK: %[[loaded:.*]] = llvm.intr.masked.load %[[vecPtr]], %[[mask]],
 // CHECK-SAME: %[[PASS_THROUGH]] {alignment = 4 : i32} :
 // CHECK-SAME: (!llvm.ptr<vector<17xf32>>, vector<17xi1>, vector<17xf32>) -> vector<17xf32>
 
 //
 // 1. Create a vector with linear indices [ 0 .. vector_length - 1 ].
 // CHECK: %[[linearIndex_b:.*]] = constant dense
@@ -1264,8 +1265,9 @@ func @transfer_read_index_1d(%A : memref<?xindex>, %base: index) -> vector<17xindex> {
 }
 // CHECK-LABEL: func @transfer_read_index_1d
 // CHECK-SAME: %[[BASE:[a-zA-Z0-9]*]]: index) -> vector<17xindex>
-// CHECK: %[[C7:.*]] = constant 7
-// CHECK: %{{.*}} = unrealized_conversion_cast %[[C7]] : index to i64
+// CHECK: %[[C7:.*]] = constant 7 : index
+// CHECK: %[[SPLAT:.*]] = splat %[[C7]] : vector<17xindex>
+// CHECK: %{{.*}} = unrealized_conversion_cast %[[SPLAT]] : vector<17xindex> to vector<17xi64>
 
 // CHECK: %[[loaded:.*]] = llvm.intr.masked.load %{{.*}}, %{{.*}}, %{{.*}} {alignment = 8 : i32} :
 // CHECK-SAME: (!llvm.ptr<vector<17xi64>>, vector<17xi1>, vector<17xi64>) -> vector<17xi64>
@@ -1384,26 +1386,6 @@ func @transfer_read_1d_mask(%A : memref<?xf32>, %base : index) -> vector<5xf32>
 
 // -----
 
-func @transfer_read_1d_cast(%A : memref<?xi32>, %base: index) -> vector<12xi8> {
-%c0 = constant 0: i32
-%v = vector.transfer_read %A[%base], %c0 {in_bounds = [true]} :
-memref<?xi32>, vector<12xi8>
-return %v: vector<12xi8>
-}
-// CHECK-LABEL: func @transfer_read_1d_cast
-// CHECK-SAME: %[[BASE:[a-zA-Z0-9]*]]: index) -> vector<12xi8>
-//
-// 1. Bitcast to vector form.
-// CHECK: %[[gep:.*]] = llvm.getelementptr {{.*}} :
-// CHECK-SAME: (!llvm.ptr<i32>, i64) -> !llvm.ptr<i32>
-// CHECK: %[[vecPtr:.*]] = llvm.bitcast %[[gep]] :
-// CHECK-SAME: !llvm.ptr<i32> to !llvm.ptr<vector<12xi8>>
-//
-// 2. Rewrite as a load.
-// CHECK: %[[loaded:.*]] = llvm.load %[[vecPtr]] {alignment = 4 : i64} : !llvm.ptr<vector<12xi8>>
-
-// -----
-
 func @genbool_1d() -> vector<8xi1> {
 %0 = vector.constant_mask [4] : vector<8xi1>
 return %0 : vector<8xi1>

@@ -1094,11 +1094,12 @@ func @type_cast_layout(%arg0: memref<4x3xf32, affine_map<(d0, d1)[s0, s1, s2] ->
 
 // -----
 
-func @store_unsupported_layout(%memref : memref<200x100xf32, affine_map<(d0, d1) -> (d1, d0)>>,
+func @store_unsupported_layout(%memref : memref<200x100xf32, affine_map<(d0, d1) -> (200*d0 + 2*d1)>>,
 %i : index, %j : index, %value : vector<8xf32>) {
-// expected-error@+1 {{'vector.store' op base memref should have a default identity layout}}
-vector.store %value, %memref[%i, %j] : memref<200x100xf32, affine_map<(d0, d1) -> (d1, d0)>>,
+// expected-error@+1 {{'vector.store' op most minor memref dim must have unit stride}}
+vector.store %value, %memref[%i, %j] : memref<200x100xf32, affine_map<(d0, d1) -> (200*d0 + 2*d1)>>,
 vector<8xf32>
 return
 }
 
 // -----

@@ -114,14 +114,11 @@ func @transfer_not_inbounds(%mem : memref<8x8xf32>, %i : index) -> vector<4xf32>
 
 // -----
 
-// TODO: transfer_read/write cannot be lowered to vector.load/store because the
-// memref has a non-default layout.
 // CHECK-LABEL: func @transfer_nondefault_layout(
 // CHECK-SAME: %[[MEM:.*]]: memref<8x8xf32, #{{.*}}>,
 // CHECK-SAME: %[[IDX:.*]]: index) -> vector<4xf32> {
-// CHECK-NEXT: %[[CF0:.*]] = constant 0.000000e+00 : f32
-// CHECK-NEXT: %[[RES:.*]] = vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]]], %[[CF0]] {in_bounds = [true]} : memref<8x8xf32, #{{.*}}>, vector<4xf32>
-// CHECK-NEXT: vector.transfer_write %[[RES]], %[[MEM]][%[[IDX]], %[[IDX]]] {in_bounds = [true]} : vector<4xf32>, memref<8x8xf32, #{{.*}}>
+// CHECK-NEXT: %[[RES:.*]] = vector.load %[[MEM]][%[[IDX]], %[[IDX]]] : memref<8x8xf32, #{{.*}}>, vector<4xf32>
+// CHECK-NEXT: vector.store %[[RES]], %[[MEM]][%[[IDX]], %[[IDX]]] : memref<8x8xf32, #{{.*}}>, vector<4xf32>
 // CHECK-NEXT: return %[[RES]] : vector<4xf32>
 // CHECK-NEXT: }
 

@@ -436,6 +436,7 @@ struct TestVectorTransferLoweringPatterns
 void runOnFunction() override {
 RewritePatternSet patterns(&getContext());
 populateVectorTransferLoweringPatterns(patterns);
+populateVectorTransferPermutationMapLoweringPatterns(patterns);
 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
 }
 };