From d1a9e9a7cbad4044ccc8e08d0217c23aca417714 Mon Sep 17 00:00:00 2001
From: Matthias Springer
Date: Sat, 17 Jul 2021 14:01:48 +0900
Subject: [PATCH] [mlir][vector] Remove vector.transfer_read/write to LLVM lowering

This simplifies the vector to LLVM lowering. Previously, both
vector.load/store and vector.transfer_read/write lowered directly to LLVM.
With this commit, there is a single path to LLVM vector load/store
instructions and vector.transfer_read/write ops must first be lowered to
vector.load/store ops.

* Remove vector.transfer_read/write to LLVM lowering.
* Allow non-unit memref strides on all but the most minor dimension for
  vector.load/store ops.
* Add maxTransferRank option to populateVectorTransferLoweringPatterns.
* vector.transfer_reads with changing element type can no longer be lowered
  to LLVM. (This functionality is needed only for SPIRV.)

Differential Revision: https://reviews.llvm.org/D106118
---
 mlir/include/mlir/Dialect/Vector/VectorOps.h  |  13 +-
 mlir/include/mlir/Dialect/Vector/VectorOps.td |  18 +-
 .../VectorToLLVM/ConvertVectorToLLVM.cpp      | 178 +-----------------
 .../VectorToLLVM/ConvertVectorToLLVMPass.cpp  |   3 +
 .../Linalg/Transforms/CodegenStrategy.cpp     |   2 +-
 mlir/lib/Dialect/Vector/VectorOps.cpp         |  36 +++-
 mlir/lib/Dialect/Vector/VectorTransforms.cpp  | 112 +++++++----
 .../VectorToLLVM/vector-to-llvm.mlir          |  34 +---
 mlir/test/Dialect/Vector/invalid.mlir         |   7 +-
 .../Vector/vector-transfer-lowering.mlir      |   7 +-
 .../Dialect/Vector/TestVectorTransforms.cpp   |   1 +
 11 files changed, 151 insertions(+), 260 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.h b/mlir/include/mlir/Dialect/Vector/VectorOps.h
index b7d2d0c0eaec..cf53e8fcff97 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.h
@@ -62,9 +62,12 @@ void populateBubbleVectorBitCastOpPatterns(RewritePatternSet &patterns);
 /// Collect a set of transfer read/write lowering patterns.
 ///
 /// These patterns lower transfer ops to simpler ops like `vector.load`,
-/// `vector.store` and `vector.broadcast`. Includes all patterns of
-/// populateVectorTransferPermutationMapLoweringPatterns.
-void populateVectorTransferLoweringPatterns(RewritePatternSet &patterns);
+/// `vector.store` and `vector.broadcast`. Only transfers with a transfer rank
+/// of at most `maxTransferRank` are lowered. This is useful when combined with
+/// VectorToSCF, which reduces the rank of vector transfer ops.
+void populateVectorTransferLoweringPatterns(
+    RewritePatternSet &patterns,
+    llvm::Optional<unsigned> maxTransferRank = llvm::None);
 
 /// Collect a set of transfer read/write lowering patterns that simplify the
 /// permutation map (e.g., converting it to a minor identity map) by inserting
@@ -185,6 +188,10 @@ ArrayAttr getVectorSubscriptAttr(Builder &b, ArrayRef<int64_t> values);
 Value getVectorReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc,
                            Value vector);
 
+/// Return true if the last dimension of the MemRefType has unit stride. Also
+/// return true for memrefs with no strides.
+bool isLastMemrefDimUnitStride(MemRefType type);
+
 namespace impl {
 /// Build the default minor identity map suitable for a vector transfer. This
 /// also handles the case memref<... 
x vector<...>> -> vector<...> in which the diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td index 9fc20647c81d..911a9c60c145 100644 --- a/mlir/include/mlir/Dialect/Vector/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td @@ -1409,9 +1409,9 @@ def Vector_LoadOp : Vector_Op<"load"> { based on the element type of the memref. The shape of the result vector type determines the shape of the slice read from the start memory address. The elements along each dimension of the slice are strided by the memref - strides. Only memref with default strides are allowed. These constraints - guarantee that elements read along the first dimension of the slice are - contiguous in memory. + strides. Only unit strides are allowed along the most minor memref + dimension. These constraints guarantee that elements read along the first + dimension of the slice are contiguous in memory. The memref element type can be a scalar or a vector type. If the memref element type is a scalar, it should match the element type of the result @@ -1470,6 +1470,8 @@ def Vector_LoadOp : Vector_Op<"load"> { } }]; + let hasFolder = 1; + let assemblyFormat = "$base `[` $indices `]` attr-dict `:` type($base) `,` type($result)"; } @@ -1484,9 +1486,9 @@ def Vector_StoreOp : Vector_Op<"store"> { memref dimension based on the element type of the memref. The shape of the vector value to store determines the shape of the slice written from the start memory address. The elements along each dimension of the slice are - strided by the memref strides. Only memref with default strides are allowed. - These constraints guarantee that elements written along the first dimension - of the slice are contiguous in memory. + strided by the memref strides. Only unit strides are allowed along the most + minor memref dimension. These constraints guarantee that elements written + along the first dimension of the slice are contiguous in memory. The memref element type can be a scalar or a vector type. 
If the memref element type is a scalar, it should match the element type of the value @@ -1544,6 +1546,8 @@ def Vector_StoreOp : Vector_Op<"store"> { } }]; + let hasFolder = 1; + let assemblyFormat = "$valueToStore `,` $base `[` $indices `]` attr-dict " "`:` type($base) `,` type($valueToStore)"; } @@ -1601,6 +1605,7 @@ def Vector_MaskedLoadOp : let assemblyFormat = "$base `[` $indices `]` `,` $mask `,` $pass_thru attr-dict `:` " "type($base) `,` type($mask) `,` type($pass_thru) `into` type($result)"; let hasCanonicalizer = 1; + let hasFolder = 1; } def Vector_MaskedStoreOp : @@ -1653,6 +1658,7 @@ def Vector_MaskedStoreOp : "$base `[` $indices `]` `,` $mask `,` $valueToStore " "attr-dict `:` type($base) `,` type($mask) `,` type($valueToStore)"; let hasCanonicalizer = 1; + let hasFolder = 1; } def Vector_GatherOp : diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index b1256cb4f613..53ce5ca3d452 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -130,18 +130,6 @@ static unsigned getAssumedAlignment(Value value) { } return align; } -// Helper that returns data layout alignment of a memref associated with a -// transfer op, including additional information from assume_alignment calls -// on the source of the transfer -LogicalResult getTransferOpAlignment(LLVMTypeConverter &typeConverter, - VectorTransferOpInterface xfer, - unsigned &align) { - if (failed(getMemRefAlignment( - typeConverter, xfer.getShapedType().cast(), align))) - return failure(); - align = std::max(align, getAssumedAlignment(xfer.source())); - return success(); -} // Helper that returns data layout alignment of a memref associated with a // load, store, scatter, or gather op, including additional information from @@ -181,79 +169,6 @@ static Value castDataPtr(ConversionPatternRewriter &rewriter, Location loc, return rewriter.create(loc, pType, ptr); } -static LogicalResult -replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter, - LLVMTypeConverter &typeConverter, Location loc, - TransferReadOp xferOp, - ArrayRef operands, Value dataPtr) { - unsigned align; - if (failed(getTransferOpAlignment(typeConverter, xferOp, align))) - return failure(); - rewriter.replaceOpWithNewOp(xferOp, dataPtr, align); - return success(); -} - -static LogicalResult -replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter, - LLVMTypeConverter &typeConverter, Location loc, - TransferReadOp xferOp, ArrayRef operands, - Value dataPtr, Value mask) { - Type vecTy = typeConverter.convertType(xferOp.getVectorType()); - if (!vecTy) - return failure(); - - auto adaptor = TransferReadOpAdaptor(operands, xferOp->getAttrDictionary()); - Value fill = rewriter.create(loc, vecTy, adaptor.padding()); - - unsigned align; - if (failed(getTransferOpAlignment(typeConverter, xferOp, align))) - return failure(); - rewriter.replaceOpWithNewOp( - xferOp, vecTy, dataPtr, mask, ValueRange{fill}, - rewriter.getI32IntegerAttr(align)); - return success(); -} - -static LogicalResult -replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter, - LLVMTypeConverter &typeConverter, Location loc, - TransferWriteOp xferOp, - ArrayRef operands, Value dataPtr) { - unsigned align; - if (failed(getTransferOpAlignment(typeConverter, xferOp, align))) - return failure(); - auto adaptor = TransferWriteOpAdaptor(operands, xferOp->getAttrDictionary()); - rewriter.replaceOpWithNewOp(xferOp, 
adaptor.vector(), dataPtr, - align); - return success(); -} - -static LogicalResult -replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter, - LLVMTypeConverter &typeConverter, Location loc, - TransferWriteOp xferOp, ArrayRef operands, - Value dataPtr, Value mask) { - unsigned align; - if (failed(getTransferOpAlignment(typeConverter, xferOp, align))) - return failure(); - - auto adaptor = TransferWriteOpAdaptor(operands, xferOp->getAttrDictionary()); - rewriter.replaceOpWithNewOp( - xferOp, adaptor.vector(), dataPtr, mask, - rewriter.getI32IntegerAttr(align)); - return success(); -} - -static TransferReadOpAdaptor getTransferOpAdapter(TransferReadOp xferOp, - ArrayRef operands) { - return TransferReadOpAdaptor(operands, xferOp->getAttrDictionary()); -} - -static TransferWriteOpAdaptor getTransferOpAdapter(TransferWriteOp xferOp, - ArrayRef operands) { - return TransferWriteOpAdaptor(operands, xferOp->getAttrDictionary()); -} - namespace { /// Conversion pattern for a vector.bitcast. @@ -1026,15 +941,6 @@ public: } }; -/// Return true if the last dimension of the MemRefType has unit stride. Also -/// return true for memrefs with no strides. -static bool isLastMemrefDimUnitStride(MemRefType type) { - int64_t offset; - SmallVector strides; - auto successStrides = getStridesAndOffset(type, strides, offset); - return succeeded(successStrides) && (strides.empty() || strides.back() == 1); -} - /// Returns the strides if the memory underlying `memRefType` has a contiguous /// static layout. static llvm::Optional> @@ -1145,83 +1051,6 @@ public: } }; -/// Conversion pattern that converts a 1-D vector transfer read/write op into a -/// a masked or unmasked read/write. -template -class VectorTransferConversion : public ConvertOpToLLVMPattern { -public: - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; - - LogicalResult - matchAndRewrite(ConcreteOp xferOp, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { - auto adaptor = getTransferOpAdapter(xferOp, operands); - - if (xferOp.getVectorType().getRank() > 1 || xferOp.indices().empty()) - return failure(); - if (xferOp.permutation_map() != - AffineMap::getMinorIdentityMap(xferOp.permutation_map().getNumInputs(), - xferOp.getVectorType().getRank(), - xferOp->getContext())) - return failure(); - auto memRefType = xferOp.getShapedType().template dyn_cast(); - if (!memRefType) - return failure(); - // Last dimension must be contiguous. (Otherwise: Use VectorToSCF.) - if (!isLastMemrefDimUnitStride(memRefType)) - return failure(); - // Out-of-bounds dims are handled by MaterializeTransferMask. - if (xferOp.hasOutOfBoundsDim()) - return failure(); - - auto toLLVMTy = [&](Type t) { - return this->getTypeConverter()->convertType(t); - }; - - Location loc = xferOp->getLoc(); - - if (auto memrefVectorElementType = - memRefType.getElementType().template dyn_cast()) { - // Memref has vector element type. - if (memrefVectorElementType.getElementType() != - xferOp.getVectorType().getElementType()) - return failure(); -#ifndef NDEBUG - // Check that memref vector type is a suffix of 'vectorType. - unsigned memrefVecEltRank = memrefVectorElementType.getRank(); - unsigned resultVecRank = xferOp.getVectorType().getRank(); - assert(memrefVecEltRank <= resultVecRank); - // TODO: Move this to isSuffix in Vector/Utils.h. 
- unsigned rankOffset = resultVecRank - memrefVecEltRank; - auto memrefVecEltShape = memrefVectorElementType.getShape(); - auto resultVecShape = xferOp.getVectorType().getShape(); - for (unsigned i = 0; i < memrefVecEltRank; ++i) - assert(memrefVecEltShape[i] != resultVecShape[rankOffset + i] && - "memref vector element shape should match suffix of vector " - "result shape."); -#endif // ifndef NDEBUG - } - - // Get the source/dst address as an LLVM vector pointer. - VectorType vtp = xferOp.getVectorType(); - Value dataPtr = this->getStridedElementPtr( - loc, memRefType, adaptor.source(), adaptor.indices(), rewriter); - Value vectorDataPtr = - castDataPtr(rewriter, loc, dataPtr, memRefType, toLLVMTy(vtp)); - - // Rewrite as an unmasked masked read / write. - if (!xferOp.mask()) - return replaceTransferOpWithLoadOrStore(rewriter, - *this->getTypeConverter(), loc, - xferOp, operands, vectorDataPtr); - - // Rewrite as a masked read / write. - return replaceTransferOpWithMasked(rewriter, *this->getTypeConverter(), loc, - xferOp, operands, vectorDataPtr, - xferOp.mask()); - } -}; - class VectorPrintOpConversion : public ConvertOpToLLVMPattern { public: using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; @@ -1450,9 +1279,10 @@ void mlir::populateVectorToLLVMConversionPatterns( VectorLoadStoreConversion, VectorGatherOpConversion, VectorScatterOpConversion, - VectorExpandLoadOpConversion, VectorCompressStoreOpConversion, - VectorTransferConversion, - VectorTransferConversion>(converter); + VectorExpandLoadOpConversion, VectorCompressStoreOpConversion>( + converter); + // Transfer ops with rank > 1 are handled by VectorToSCF. + populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1); } void mlir::populateVectorToLLVMMatrixConversionPatterns( diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp index 34359ef863e7..1a708dc4da6c 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp @@ -64,6 +64,8 @@ void LowerVectorToLLVMPass::runOnOperation() { populateVectorToVectorCanonicalizationPatterns(patterns); populateVectorContractLoweringPatterns(patterns); populateVectorTransposeLoweringPatterns(patterns); + // Vector transfer ops with rank > 1 should be lowered with VectorToSCF. 
+ populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1); (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); } @@ -71,6 +73,7 @@ void LowerVectorToLLVMPass::runOnOperation() { LLVMTypeConverter converter(&getContext()); RewritePatternSet patterns(&getContext()); populateVectorMaskMaterializationPatterns(patterns, enableIndexOptimizations); + populateVectorTransferLoweringPatterns(patterns); populateVectorToLLVMMatrixConversionPatterns(converter, patterns); populateVectorToLLVMConversionPatterns(converter, patterns, reassociateFPReductions); diff --git a/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp b/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp index 93a8b475a94c..cd4d525d6a90 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp @@ -89,7 +89,7 @@ void mlir::linalg::CodegenStrategy::transform(FuncOp func) const { .add( vectorTransformsOptions, context); - vector::populateVectorTransferLoweringPatterns( + vector::populateVectorTransferPermutationMapLoweringPatterns( vectorContractLoweringPatterns); (void)applyPatternsAndFoldGreedily( func, std::move(vectorContractLoweringPatterns)); diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp index 9fbc6c3711d8..045fbab987b6 100644 --- a/mlir/lib/Dialect/Vector/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/VectorOps.cpp @@ -102,6 +102,15 @@ static bool isSupportedCombiningKind(CombiningKind combiningKind, return false; } +/// Return true if the last dimension of the MemRefType has unit stride. Also +/// return true for memrefs with no strides. +bool mlir::vector::isLastMemrefDimUnitStride(MemRefType type) { + int64_t offset; + SmallVector strides; + auto successStrides = getStridesAndOffset(type, strides, offset); + return succeeded(successStrides) && (strides.empty() || strides.back() == 1); +} + //===----------------------------------------------------------------------===// // CombiningKindAttr //===----------------------------------------------------------------------===// @@ -2953,9 +2962,8 @@ void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results, static LogicalResult verifyLoadStoreMemRefLayout(Operation *op, MemRefType memRefTy) { - auto affineMaps = memRefTy.getAffineMaps(); - if (!affineMaps.empty()) - return op->emitOpError("base memref should have a default identity layout"); + if (!isLastMemrefDimUnitStride(memRefTy)) + return op->emitOpError("most minor memref dim must have unit stride"); return success(); } @@ -2981,6 +2989,12 @@ static LogicalResult verify(vector::LoadOp op) { return success(); } +OpFoldResult LoadOp::fold(ArrayRef) { + if (succeeded(foldMemRefCast(*this))) + return getResult(); + return OpFoldResult(); +} + //===----------------------------------------------------------------------===// // StoreOp //===----------------------------------------------------------------------===// @@ -3008,6 +3022,11 @@ static LogicalResult verify(vector::StoreOp op) { return success(); } +LogicalResult StoreOp::fold(ArrayRef operands, + SmallVectorImpl &results) { + return foldMemRefCast(*this); +} + //===----------------------------------------------------------------------===// // MaskedLoadOp //===----------------------------------------------------------------------===// @@ -3056,6 +3075,12 @@ void MaskedLoadOp::getCanonicalizationPatterns(RewritePatternSet &results, results.add(context); } +OpFoldResult MaskedLoadOp::fold(ArrayRef) { + if 
(succeeded(foldMemRefCast(*this))) + return getResult(); + return OpFoldResult(); +} + //===----------------------------------------------------------------------===// // MaskedStoreOp //===----------------------------------------------------------------------===// @@ -3101,6 +3126,11 @@ void MaskedStoreOp::getCanonicalizationPatterns(RewritePatternSet &results, results.add(context); } +LogicalResult MaskedStoreOp::fold(ArrayRef operands, + SmallVectorImpl &results) { + return foldMemRefCast(*this); +} + //===----------------------------------------------------------------------===// // GatherOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp index 4bd1ee15dece..2a99eb6e7063 100644 --- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp @@ -2464,26 +2464,34 @@ struct TransferWriteInsertPattern /// Progressive lowering of transfer_read. This pattern supports lowering of /// `vector.transfer_read` to a combination of `vector.load` and /// `vector.broadcast` if all of the following hold: -/// - The op reads from a memref with the default layout. +/// - Stride of most minor memref dimension must be 1. /// - Out-of-bounds masking is not required. /// - If the memref's element type is a vector type then it coincides with the /// result type. /// - The permutation map doesn't perform permutation (broadcasting is allowed). -/// - The op has no mask. struct TransferReadToVectorLoadLowering : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + TransferReadToVectorLoadLowering(MLIRContext *context, + llvm::Optional maxRank) + : OpRewritePattern(context), + maxTransferRank(maxRank) {} LogicalResult matchAndRewrite(vector::TransferReadOp read, PatternRewriter &rewriter) const override { + if (maxTransferRank && read.getVectorType().getRank() > *maxTransferRank) + return failure(); SmallVector broadcastedDims; - // TODO: Support permutations. + // Permutations are handled by VectorToSCF or + // populateVectorTransferPermutationMapLoweringPatterns. if (!read.permutation_map().isMinorIdentityWithBroadcasting( &broadcastedDims)) return failure(); auto memRefType = read.getShapedType().dyn_cast(); if (!memRefType) return failure(); + // Non-unit strides are handled by VectorToSCF. + if (!vector::isLastMemrefDimUnitStride(memRefType)) + return failure(); // If there is broadcasting involved then we first load the unbroadcasted // vector, and then broadcast it with `vector.broadcast`. @@ -2497,32 +2505,44 @@ struct TransferReadToVectorLoadLowering // `vector.load` supports vector types as memref's elements only when the // resulting vector type is the same as the element type. - if (memRefType.getElementType().isa() && - memRefType.getElementType() != unbroadcastedVectorType) + auto memrefElTy = memRefType.getElementType(); + if (memrefElTy.isa() && memrefElTy != unbroadcastedVectorType) return failure(); - // Only the default layout is supported by `vector.load`. - // TODO: Support non-default layouts. - if (!memRefType.getAffineMaps().empty()) - return failure(); - // TODO: When out-of-bounds masking is required, we can create a - // MaskedLoadOp. - if (read.hasOutOfBoundsDim()) - return failure(); - if (read.mask()) + // Otherwise, element types of the memref and the vector must match. 
+ if (!memrefElTy.isa() && + memrefElTy != read.getVectorType().getElementType()) return failure(); - auto loadOp = rewriter.create( - read.getLoc(), unbroadcastedVectorType, read.source(), read.indices()); + // Out-of-bounds dims are handled by MaterializeTransferMask. + if (read.hasOutOfBoundsDim()) + return failure(); + + // Create vector load op. + Operation *loadOp; + if (read.mask()) { + Value fill = rewriter.create( + read.getLoc(), unbroadcastedVectorType, read.padding()); + loadOp = rewriter.create( + read.getLoc(), unbroadcastedVectorType, read.source(), read.indices(), + read.mask(), fill); + } else { + loadOp = rewriter.create(read.getLoc(), + unbroadcastedVectorType, + read.source(), read.indices()); + } + // Insert a broadcasting op if required. if (!broadcastedDims.empty()) { rewriter.replaceOpWithNewOp( - read, read.getVectorType(), loadOp.result()); + read, read.getVectorType(), loadOp->getResult(0)); } else { - rewriter.replaceOp(read, loadOp.result()); + rewriter.replaceOp(read, loadOp->getResult(0)); } return success(); } + + llvm::Optional maxTransferRank; }; /// Replace a scalar vector.load with a memref.load. @@ -2545,44 +2565,56 @@ struct VectorLoadToMemrefLoadLowering /// Progressive lowering of transfer_write. This pattern supports lowering of /// `vector.transfer_write` to `vector.store` if all of the following hold: -/// - The op writes to a memref with the default layout. +/// - Stride of most minor memref dimension must be 1. /// - Out-of-bounds masking is not required. /// - If the memref's element type is a vector type then it coincides with the /// type of the written value. /// - The permutation map is the minor identity map (neither permutation nor /// broadcasting is allowed). -/// - The op has no mask. struct TransferWriteToVectorStoreLowering : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + TransferWriteToVectorStoreLowering(MLIRContext *context, + llvm::Optional maxRank) + : OpRewritePattern(context), + maxTransferRank(maxRank) {} LogicalResult matchAndRewrite(vector::TransferWriteOp write, PatternRewriter &rewriter) const override { - // TODO: Support non-minor-identity maps + if (maxTransferRank && write.getVectorType().getRank() > *maxTransferRank) + return failure(); + // Permutations are handled by VectorToSCF or + // populateVectorTransferPermutationMapLoweringPatterns. if (!write.permutation_map().isMinorIdentity()) return failure(); auto memRefType = write.getShapedType().dyn_cast(); if (!memRefType) return failure(); + // Non-unit strides are handled by VectorToSCF. + if (!vector::isLastMemrefDimUnitStride(memRefType)) + return failure(); // `vector.store` supports vector types as memref's elements only when the // type of the vector value being written is the same as the element type. - if (memRefType.getElementType().isa() && - memRefType.getElementType() != write.getVectorType()) + auto memrefElTy = memRefType.getElementType(); + if (memrefElTy.isa() && memrefElTy != write.getVectorType()) return failure(); - // Only the default layout is supported by `vector.store`. - // TODO: Support non-default layouts. - if (!memRefType.getAffineMaps().empty()) + // Otherwise, element types of the memref and the vector must match. + if (!memrefElTy.isa() && + memrefElTy != write.getVectorType().getElementType()) return failure(); - // TODO: When out-of-bounds masking is required, we can create a - // MaskedStoreOp. + // Out-of-bounds dims are handled by MaterializeTransferMask. 
if (write.hasOutOfBoundsDim()) return failure(); - if (write.mask()) - return failure(); - rewriter.replaceOpWithNewOp( - write, write.vector(), write.source(), write.indices()); + if (write.mask()) { + rewriter.replaceOpWithNewOp( + write, write.source(), write.indices(), write.mask(), write.vector()); + } else { + rewriter.replaceOpWithNewOp( + write, write.vector(), write.source(), write.indices()); + } return success(); } + + llvm::Optional maxTransferRank; }; /// Transpose a vector transfer op's `in_bounds` attribute according to given @@ -2624,6 +2656,8 @@ struct TransferReadPermutationLowering PatternRewriter &rewriter) const override { SmallVector permutation; AffineMap map = op.permutation_map(); + if (map.getNumResults() == 0) + return failure(); if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation)) return failure(); AffineMap permutationMap = @@ -3680,11 +3714,11 @@ void mlir::vector::populateVectorTransferPermutationMapLoweringPatterns( } void mlir::vector::populateVectorTransferLoweringPatterns( - RewritePatternSet &patterns) { - patterns - .add(patterns.getContext()); - populateVectorTransferPermutationMapLoweringPatterns(patterns); + RewritePatternSet &patterns, llvm::Optional maxTransferRank) { + patterns.add(patterns.getContext(), + maxTransferRank); + patterns.add(patterns.getContext()); } void mlir::vector::populateVectorMultiReductionLoweringPatterns( diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir index 329b79a19508..afb007c9e6b3 100644 --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -1212,18 +1212,19 @@ func @transfer_read_1d(%A : memref, %base: index) -> vector<17xf32> { // CHECK: %[[dimVec:.*]] = splat %[[dtrunc]] : vector<17xi32> // CHECK: %[[mask:.*]] = cmpi slt, %[[offsetVec2]], %[[dimVec]] : vector<17xi32> // -// 4. Bitcast to vector form. +// 4. Create pass-through vector. +// CHECK: %[[PASS_THROUGH:.*]] = splat %[[c7]] : vector<17xf32> +// +// 5. Bitcast to vector form. // CHECK: %[[gep:.*]] = llvm.getelementptr {{.*}} : // CHECK-SAME: (!llvm.ptr, i64) -> !llvm.ptr // CHECK: %[[vecPtr:.*]] = llvm.bitcast %[[gep]] : // CHECK-SAME: !llvm.ptr to !llvm.ptr> // -// 5. Rewrite as a masked read. -// CHECK: %[[PASS_THROUGH:.*]] = splat %[[c7]] : vector<17xf32> +// 6. Rewrite as a masked read. // CHECK: %[[loaded:.*]] = llvm.intr.masked.load %[[vecPtr]], %[[mask]], // CHECK-SAME: %[[PASS_THROUGH]] {alignment = 4 : i32} : // CHECK-SAME: (!llvm.ptr>, vector<17xi1>, vector<17xf32>) -> vector<17xf32> - // // 1. Create a vector with linear indices [ 0 .. vector_length - 1 ]. 
// CHECK: %[[linearIndex_b:.*]] = constant dense @@ -1264,8 +1265,9 @@ func @transfer_read_index_1d(%A : memref, %base: index) -> vector<17xin } // CHECK-LABEL: func @transfer_read_index_1d // CHECK-SAME: %[[BASE:[a-zA-Z0-9]*]]: index) -> vector<17xindex> -// CHECK: %[[C7:.*]] = constant 7 -// CHECK: %{{.*}} = unrealized_conversion_cast %[[C7]] : index to i64 +// CHECK: %[[C7:.*]] = constant 7 : index +// CHECK: %[[SPLAT:.*]] = splat %[[C7]] : vector<17xindex> +// CHECK: %{{.*}} = unrealized_conversion_cast %[[SPLAT]] : vector<17xindex> to vector<17xi64> // CHECK: %[[loaded:.*]] = llvm.intr.masked.load %{{.*}}, %{{.*}}, %{{.*}} {alignment = 8 : i32} : // CHECK-SAME: (!llvm.ptr>, vector<17xi1>, vector<17xi64>) -> vector<17xi64> @@ -1384,26 +1386,6 @@ func @transfer_read_1d_mask(%A : memref, %base : index) -> vector<5xf32> // ----- -func @transfer_read_1d_cast(%A : memref, %base: index) -> vector<12xi8> { - %c0 = constant 0: i32 - %v = vector.transfer_read %A[%base], %c0 {in_bounds = [true]} : - memref, vector<12xi8> - return %v: vector<12xi8> -} -// CHECK-LABEL: func @transfer_read_1d_cast -// CHECK-SAME: %[[BASE:[a-zA-Z0-9]*]]: index) -> vector<12xi8> -// -// 1. Bitcast to vector form. -// CHECK: %[[gep:.*]] = llvm.getelementptr {{.*}} : -// CHECK-SAME: (!llvm.ptr, i64) -> !llvm.ptr -// CHECK: %[[vecPtr:.*]] = llvm.bitcast %[[gep]] : -// CHECK-SAME: !llvm.ptr to !llvm.ptr> -// -// 2. Rewrite as a load. -// CHECK: %[[loaded:.*]] = llvm.load %[[vecPtr]] {alignment = 4 : i64} : !llvm.ptr> - -// ----- - func @genbool_1d() -> vector<8xi1> { %0 = vector.constant_mask [4] : vector<8xi1> return %0 : vector<8xi1> diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir index e8ad86f5fffb..7fb4ecb5b0d3 100644 --- a/mlir/test/Dialect/Vector/invalid.mlir +++ b/mlir/test/Dialect/Vector/invalid.mlir @@ -1094,11 +1094,12 @@ func @type_cast_layout(%arg0: memref<4x3xf32, affine_map<(d0, d1)[s0, s1, s2] -> // ----- -func @store_unsupported_layout(%memref : memref<200x100xf32, affine_map<(d0, d1) -> (d1, d0)>>, +func @store_unsupported_layout(%memref : memref<200x100xf32, affine_map<(d0, d1) -> (200*d0 + 2*d1)>>, %i : index, %j : index, %value : vector<8xf32>) { - // expected-error@+1 {{'vector.store' op base memref should have a default identity layout}} - vector.store %value, %memref[%i, %j] : memref<200x100xf32, affine_map<(d0, d1) -> (d1, d0)>>, + // expected-error@+1 {{'vector.store' op most minor memref dim must have unit stride}} + vector.store %value, %memref[%i, %j] : memref<200x100xf32, affine_map<(d0, d1) -> (200*d0 + 2*d1)>>, vector<8xf32> + return } // ----- diff --git a/mlir/test/Dialect/Vector/vector-transfer-lowering.mlir b/mlir/test/Dialect/Vector/vector-transfer-lowering.mlir index 931c3ba91774..910100d61af4 100644 --- a/mlir/test/Dialect/Vector/vector-transfer-lowering.mlir +++ b/mlir/test/Dialect/Vector/vector-transfer-lowering.mlir @@ -114,14 +114,11 @@ func @transfer_not_inbounds(%mem : memref<8x8xf32>, %i : index) -> vector<4xf32> // ----- -// TODO: transfer_read/write cannot be lowered to vector.load/store because the -// memref has a non-default layout. 
// CHECK-LABEL: func @transfer_nondefault_layout( // CHECK-SAME: %[[MEM:.*]]: memref<8x8xf32, #{{.*}}>, // CHECK-SAME: %[[IDX:.*]]: index) -> vector<4xf32> { -// CHECK-NEXT: %[[CF0:.*]] = constant 0.000000e+00 : f32 -// CHECK-NEXT: %[[RES:.*]] = vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]]], %[[CF0]] {in_bounds = [true]} : memref<8x8xf32, #{{.*}}>, vector<4xf32> -// CHECK-NEXT: vector.transfer_write %[[RES]], %[[MEM]][%[[IDX]], %[[IDX]]] {in_bounds = [true]} : vector<4xf32>, memref<8x8xf32, #{{.*}}> +// CHECK-NEXT: %[[RES:.*]] = vector.load %[[MEM]][%[[IDX]], %[[IDX]]] : memref<8x8xf32, #{{.*}}>, vector<4xf32> +// CHECK-NEXT: vector.store %[[RES]], %[[MEM]][%[[IDX]], %[[IDX]]] : memref<8x8xf32, #{{.*}}>, vector<4xf32> // CHECK-NEXT: return %[[RES]] : vector<4xf32> // CHECK-NEXT: } diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp index fbc80116b955..11b56a583cc8 100644 --- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp +++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp @@ -436,6 +436,7 @@ struct TestVectorTransferLoweringPatterns void runOnFunction() override { RewritePatternSet patterns(&getContext()); populateVectorTransferLoweringPatterns(patterns); + populateVectorTransferPermutationMapLoweringPatterns(patterns); (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); } };
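
A minimal sketch of the new lowering path, using the op syntax of the tests in this
patch (the function and value names below are illustrative only, not taken from the
patch): a 1-D in-bounds, unmasked transfer is first rewritten by
TransferReadToVectorLoadLowering / TransferWriteToVectorStoreLowering and only then
converted to LLVM load/store by the existing VectorLoadStoreConversion patterns.

  // Hypothetical input (names are illustrative).
  func @example(%mem : memref<8x8xf32>, %i : index) -> vector<4xf32> {
    %cf0 = constant 0.0 : f32
    %v = vector.transfer_read %mem[%i, %i], %cf0 {in_bounds = [true]}
        : memref<8x8xf32>, vector<4xf32>
    vector.transfer_write %v, %mem[%i, %i] {in_bounds = [true]}
        : vector<4xf32>, memref<8x8xf32>
    return %v : vector<4xf32>
  }

  // After populateVectorTransferLoweringPatterns, before conversion to LLVM:
  //   %v = vector.load %mem[%i, %i] : memref<8x8xf32>, vector<4xf32>
  //   vector.store %v, %mem[%i, %i] : memref<8x8xf32>, vector<4xf32>

Masked 1-D transfers follow the same path and are rewritten to vector.maskedload /
vector.maskedstore (using the padding value as pass-through), while transfers of
rank > 1 are left to VectorToSCF because the patterns are registered with
/*maxTransferRank=*/1.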