[mlir][Vector] NFC - Reorganize vector patterns

Vector dialect patterns have grown enormously in the past year to a point where they are now impenetrable.
Start reorganizing them towards finer-grained control.

Differential Revision: https://reviews.llvm.org/D146736
Author: Nicolas Vasilache
Date: 2023-03-23 08:32:48 -07:00
parent 637048f122
commit 2bc4c3e920
26 changed files with 3103 additions and 2653 deletions


@@ -110,43 +110,11 @@ void populateFlattenVectorTransferPatterns(RewritePatternSet &patterns,
void populateBubbleVectorBitCastOpPatterns(RewritePatternSet &patterns,
PatternBenefit benefit = 1);
/// Collect a set of transfer read/write lowering patterns.
///
/// These patterns lower transfer ops to simpler ops like `vector.load`,
/// `vector.store` and `vector.broadcast`. Only transfers with a transfer rank
/// of at most `maxTransferRank` are lowered. This is useful when combined with
/// VectorToSCF, which reduces the rank of vector transfer ops.
void populateVectorTransferLoweringPatterns(
RewritePatternSet &patterns,
std::optional<unsigned> maxTransferRank = std::nullopt,
PatternBenefit benefit = 1);
/// These patterns materialize masks for various vector ops such as transfers.
void populateVectorMaskMaterializationPatterns(RewritePatternSet &patterns,
bool force32BitVectorIndices,
PatternBenefit benefit = 1);
/// Collects patterns to progressively lower vector.broadcast ops on high-D
/// vectors to low-D vector ops.
void populateVectorBroadcastLoweringPatterns(RewritePatternSet &patterns,
PatternBenefit benefit = 1);
/// Collects patterns to progressively lower vector mask ops into elementary
/// selection and insertion ops.
void populateVectorMaskOpLoweringPatterns(RewritePatternSet &patterns,
PatternBenefit benefit = 1);
/// Collects patterns to progressively lower vector.shape_cast ops on high-D
/// vectors into 1-D/2-D vector ops by generating data movement extract/insert
/// ops.
void populateVectorShapeCastLoweringPatterns(RewritePatternSet &patterns,
PatternBenefit benefit = 1);
/// Collects patterns that lower scalar vector transfer ops to memref loads and
/// stores when beneficial.
void populateScalarVectorTransferLoweringPatterns(RewritePatternSet &patterns,
PatternBenefit benefit = 1);
/// Returns the integer type required for subscripts in the vector dialect.
IntegerType getVectorSubscriptType(Builder &builder);
@@ -214,8 +182,8 @@ void createMaskOpRegion(OpBuilder &builder, Operation *maskableOp);
/// Creates a vector.mask operation around a maskable operation. Returns the
/// vector.mask operation if the mask provided is valid. Otherwise, returns the
/// maskable operation itself.
Operation *maskOperation(OpBuilder &builder, Operation *maskableOp,
Value mask, Value passthru = Value());
Operation *maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask,
Value passthru = Value());
/// Creates a vector select operation that picks values from `newValue` or
/// `passthru` for each result vector lane based on `mask`. This utility is used


@@ -0,0 +1,248 @@
//===- LoweringPatterns.h - Vector rewrite patterns --------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H
#define MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
namespace mlir {
class RewritePatternSet;
namespace vector {
//===----------------------------------------------------------------------===//
// Lowering pattern populate functions
//===----------------------------------------------------------------------===//
/// Populate the pattern set with the following patterns:
///
/// [OuterProductOpLowering]
/// Progressively lower a `vector.outerproduct` to linearized
/// `vector.extract` + `vector.fma` + `vector.insert`.
///
/// [ContractionOpLowering]
/// Progressive lowering of ContractionOp.
/// One:
/// %x = vector.contract with at least one free/batch dimension
/// is replaced by:
/// %a = vector.contract with one less free/batch dimension
/// %b = vector.contract with one less free/batch dimension
/// ..
/// %x = combine %a %b ..
///
/// [ContractionOpToMatmulOpLowering]
/// Progressively lower a `vector.contract` with row-major matmul semantics to
/// linearized `vector.shape_cast` + `vector.matmul` on the way to
/// `llvm.matrix.multiply`.
///
/// [ContractionOpToDotLowering]
/// Progressively lower a `vector.contract` with row-major matmul semantics to
/// linearized `vector.extract` + `vector.reduce` + `vector.insert`.
///
/// [ContractionOpToOuterProductOpLowering]
/// Progressively lower a `vector.contract` with row-major matmul semantics to
/// linearized `vector.extract` + `vector.outerproduct` + `vector.insert`.
void populateVectorContractLoweringPatterns(
RewritePatternSet &patterns, VectorTransformsOptions options,
PatternBenefit benefit = 1, bool disableOuterProductLowering = false);
/// Collect a set of patterns to convert vector.multi_reduction op into
/// a sequence of vector.reduction ops. The patterns comprise:
///
/// [InnerOuterDimReductionConversion]
/// Rewrites vector.multi_reduction such that all reduction dimensions are
/// either innermost or outermost, by adding the proper vector.transpose
/// operations.
///
/// [ReduceMultiDimReductionRank]
/// Once in innermost or outermost reduction
/// form, rewrites n-D vector.multi_reduction into 2-D vector.multi_reduction,
/// by introducing vector.shape_cast ops to collapse + multi-reduce + expand
/// back.
///
/// [TwoDimMultiReductionToElementWise]
/// Once in 2-D vector.multi_reduction form, with an **outermost** reduction
/// dimension, unroll the outer dimension to obtain a sequence of 1-D vector
/// ops. This also has an opportunity for tree-reduction (in the future).
///
/// [TwoDimMultiReductionToReduction]
/// Once in 2-D vector.multi_reduction form, with an **innermost** reduction
/// dimension, unroll the outer dimension to obtain a sequence of extract +
/// vector.reduction + insert. This can further lower to horizontal reduction
/// ops.
///
/// [OneDimMultiReductionToTwoDim]
/// For cases that reduce to 1-D vector<k> reduction (and are thus missing
/// either a parallel or a reduction), we lift them back up to 2-D with a simple
/// vector.shape_cast to vector<1xk> so that the other patterns can kick in,
/// thus fully exiting out of the vector.multi_reduction abstraction.
void populateVectorMultiReductionLoweringPatterns(
RewritePatternSet &patterns, VectorMultiReductionLowering options,
PatternBenefit benefit = 1);
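A hedged companion sketch (hypothetical helper, reusing the includes from the sketch above) showing how the strategy enum steers the 2-D patterns: `InnerReduction` drives TwoDimMultiReductionToReduction, `InnerParallel` drives TwoDimMultiReductionToElementWise.

// Hypothetical helper: lower multi_reduction with the reduction dim innermost.
static void lowerMultiReductions(mlir::Operation *root) {
  mlir::RewritePatternSet patterns(root->getContext());
  // InnerReduction: unroll to extract + vector.reduction + insert.
  mlir::vector::populateVectorMultiReductionLoweringPatterns(
      patterns, mlir::vector::VectorMultiReductionLowering::InnerReduction);
  (void)mlir::applyPatternsAndFoldGreedily(root, std::move(patterns));
}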
/// Populate the pattern set with the following patterns:
///
/// [BroadcastOpLowering]
/// Progressive lowering of BroadcastOp to ExtractOp + InsertOp + lower-D
/// BroadcastOp until dim 1.
void populateVectorBroadcastLoweringPatterns(RewritePatternSet &patterns,
PatternBenefit benefit = 1);
/// Populate the pattern set with the following patterns:
///
/// [CreateMaskOpLowering]
/// Progressive lowering of CreateMaskOp to lower-D CreateMaskOp until dim 1.
///
/// [ConstantMaskOpLowering]
/// Progressive lowering of ConstantMaskOp to lower-D ConstantMaskOp until
/// dim 1.
void populateVectorMaskOpLoweringPatterns(RewritePatternSet &patterns,
PatternBenefit benefit = 1);
/// Collects patterns that lower scalar vector transfer ops to memref loads and
/// stores when beneficial.
void populateScalarVectorTransferLoweringPatterns(RewritePatternSet &patterns,
PatternBenefit benefit = 1);
/// Populate the pattern set with the following patterns:
///
/// [ShapeCastOp2DDownCastRewritePattern]
/// ShapeCastOp 2D -> 1D downcast serves the purpose of flattening 2-D to 1-D
/// vectors progressively.
///
/// [ShapeCastOp2DUpCastRewritePattern]
/// ShapeCastOp 1D -> 2D upcast serves the purpose of unflattening 2-D from 1-D
/// vectors progressively.
///
/// [ShapeCastOpRewritePattern]
/// Reference lowering to fully unrolled sequences of single element ExtractOp +
/// InsertOp. Note that applying this pattern can almost always be considered a
/// performance bug.
void populateVectorShapeCastLoweringPatterns(RewritePatternSet &patterns,
PatternBenefit benefit = 1);
/// Populate the pattern set with the following patterns:
///
/// [TransposeOpLowering]
///
/// [TransposeOp2DToShuffleLowering]
///
void populateVectorTransposeLoweringPatterns(RewritePatternSet &patterns,
VectorTransformsOptions options,
PatternBenefit benefit = 1);
/// Populate the pattern set with the following patterns:
///
/// [TransferReadToVectorLoadLowering]
/// Progressive lowering of transfer_read. This pattern supports lowering of
/// `vector.transfer_read` to a combination of `vector.load` and
/// `vector.broadcast`.
///
/// [TransferWriteToVectorStoreLowering]
/// Progressive lowering of transfer_write. This pattern supports lowering of
/// `vector.transfer_write` to `vector.store`.
///
/// [VectorLoadToMemrefLoadLowering]
/// Replace a 0-d vector.load with a memref.load + vector.broadcast.
///
/// [VectorStoreToMemrefStoreLowering]
/// Replace a 0-d vector.store with a vector.extractelement + memref.store.
///
/// These patterns lower transfer ops to simpler ops like `vector.load`,
/// `vector.store` and `vector.broadcast`. Only transfers with a transfer rank
/// of at most `maxTransferRank` are lowered. This is useful when combined with
/// VectorToSCF, which reduces the rank of vector transfer ops.
void populateVectorTransferLoweringPatterns(
RewritePatternSet &patterns,
std::optional<unsigned> maxTransferRank = std::nullopt,
PatternBenefit benefit = 1);
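A hedged sketch of the combination described above (hypothetical helper; `populateVectorToSCFConversionPatterns` is the existing VectorToSCF entry point): rank-reduce transfers first, then lower the surviving rank-1 transfers.

#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"

// Hypothetical helper: VectorToSCF peels high-rank transfers into loops,
// after which rank-1 transfers become vector.load/vector.store.
static void lowerTransfers(mlir::Operation *root) {
  mlir::RewritePatternSet patterns(root->getContext());
  mlir::populateVectorToSCFConversionPatterns(patterns);
  mlir::vector::populateVectorTransferLoweringPatterns(
      patterns, /*maxTransferRank=*/1);
  (void)mlir::applyPatternsAndFoldGreedily(root, std::move(patterns));
}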
/// Collect a set of transfer read/write lowering patterns that simplify the
/// permutation map (e.g., converting it to a minor identity map) by inserting
/// broadcasts and transposes. More specifically:
///
/// [TransferReadPermutationLowering]
/// Lower transfer_read op with permutation into a transfer_read with a
/// permutation map composed of leading zeros followed by a minor identity +
/// vector.transpose op.
/// Ex:
/// vector.transfer_read ...
/// permutation_map: (d0, d1, d2) -> (0, d1)
/// into:
/// %v = vector.transfer_read ...
/// permutation_map: (d0, d1, d2) -> (d1, 0)
/// vector.transpose %v, [1, 0]
///
/// vector.transfer_read ...
/// permutation_map: (d0, d1, d2, d3) -> (0, 0, 0, d1, d3)
/// into:
/// %v = vector.transfer_read ...
/// permutation_map: (d0, d1, d2, d3) -> (0, 0, d1, 0, d3)
/// vector.transpose %v, [0, 1, 3, 2, 4]
/// Note that an alternative is to transform it to linalg.transpose +
/// vector.transfer_read to do the transpose in memory instead.
///
/// [TransferWritePermutationLowering]
/// Lower transfer_write op with permutation into a transfer_write with a
/// minor identity permutation map. (transfer_write ops cannot have broadcasts.)
/// Ex:
/// vector.transfer_write %v ...
/// permutation_map: (d0, d1, d2) -> (d2, d0, d1)
/// into:
/// %tmp = vector.transpose %v, [2, 0, 1]
/// vector.transfer_write %tmp ...
/// permutation_map: (d0, d1, d2) -> (d0, d1, d2)
///
/// vector.transfer_write %v ...
/// permutation_map: (d0, d1, d2, d3) -> (d3, d2)
/// into:
/// %tmp = vector.transpose %v, [1, 0]
/// %v = vector.transfer_write %tmp ...
/// permutation_map: (d0, d1, d2, d3) -> (d2, d3)
///
/// [TransferOpReduceRank]
/// Lower transfer_read op with broadcast in the leading dimensions into
/// transfer_read of lower rank + vector.broadcast.
/// Ex: vector.transfer_read ...
/// permutation_map: (d0, d1, d2, d3) -> (0, d1, 0, d3)
/// into:
/// %v = vector.transfer_read ...
/// permutation_map: (d0, d1, d2, d3) -> (d1, 0, d3)
/// vector.broadcast %v
void populateVectorTransferPermutationMapLoweringPatterns(
RewritePatternSet &patterns, PatternBenefit benefit = 1);
/// Populate the pattern set with the following patterns:
///
/// [ScanToArithOps]
/// Convert vector.scan op into arith ops and vector.insert_strided_slice /
/// vector.extract_strided_slice.
void populateVectorScanLoweringPatterns(RewritePatternSet &patterns,
PatternBenefit benefit = 1);
/// Populate the pattern set with the following patterns:
///
/// [FlattenGather]
/// Flattens 2 or more dimensional `vector.gather` ops by unrolling the
/// outermost dimension.
///
/// [Gather1DToConditionalLoads]
/// Turns 1-d `vector.gather` into a scalarized sequence of `vector.load`s or
/// `tensor.extract`s. To avoid out-of-bounds memory accesses, these
/// loads/extracts are made conditional using `scf.if` ops.
void populateVectorGatherLoweringPatterns(RewritePatternSet &patterns,
PatternBenefit benefit = 1);
/// Populates instances of `MaskOpRewritePattern` to lower masked operations
/// with `vector.mask`. Patterns should rewrite the `vector.mask` operation and
/// not its nested `MaskableOpInterface`.
void populateVectorMaskLoweringPatternsForSideEffectingOps(
RewritePatternSet &patterns);
} // namespace vector
} // namespace mlir
#endif // MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H
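Taken together, these populate functions compose into a staged pipeline; a hedged end-to-end sketch (hypothetical wrapper) mirroring the stages of LowerVectorsOp later in this commit:

static void lowerVectorOpsStaged(
    mlir::Operation *root, mlir::vector::VectorTransformsOptions options) {
  mlir::RewritePatternSet patterns(root->getContext());
  // Stage order mirrors LowerVectorsOp: contractions, permutation maps,
  // multi-reductions, then the remaining element/shape lowerings.
  mlir::vector::populateVectorContractLoweringPatterns(patterns, options);
  mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
  mlir::vector::populateVectorMultiReductionLoweringPatterns(
      patterns, options.vectorMultiReductionLowering);
  mlir::vector::populateVectorBroadcastLoweringPatterns(patterns);
  mlir::vector::populateVectorShapeCastLoweringPatterns(patterns);
  mlir::vector::populateVectorTransposeLoweringPatterns(patterns, options);
  (void)mlir::applyPatternsAndFoldGreedily(root, std::move(patterns));
}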


@@ -22,12 +22,6 @@ std::unique_ptr<Pass> createVectorBufferizePass();
/// Creates an instance of the `vector.mask` lowering pass.
std::unique_ptr<Pass> createLowerVectorMaskPass();
/// Populates instances of `MaskOpRewritePattern` to lower masked operations
/// with `vector.mask`. Patterns should rewrite the `vector.mask` operation and
/// not its nested `MaskableOpInterface`.
void populateVectorMaskLoweringPatternsForSideEffectingOps(
RewritePatternSet &patterns);
//===----------------------------------------------------------------------===//
// Registration
//===----------------------------------------------------------------------===//


@@ -9,8 +9,8 @@
#ifndef MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORREWRITEPATTERNS_H
#define MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORREWRITEPATTERNS_H
#include <utility>
#include <optional>
#include <utility>
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransformsEnums.h.inc"
@@ -23,42 +23,7 @@ namespace mlir {
class RewritePatternSet;
namespace vector {
//===----------------------------------------------------------------------===//
// Vector transformation options exposed as auxiliary structs.
//===----------------------------------------------------------------------===//
/// Structure to control the behavior of vector transform patterns.
struct VectorTransformsOptions {
/// Option to control the lowering of vector.contract.
VectorContractLowering vectorContractLowering = VectorContractLowering::Dot;
VectorTransformsOptions &
setVectorTransformsOptions(VectorContractLowering opt) {
vectorContractLowering = opt;
return *this;
}
/// Option to control the lowering of vector.multi_reduction.
VectorMultiReductionLowering vectorMultiReductionLowering =
VectorMultiReductionLowering::InnerParallel;
VectorTransformsOptions &
setVectorMultiReductionLowering(VectorMultiReductionLowering opt) {
vectorMultiReductionLowering = opt;
return *this;
}
/// Option to control the lowering of vector.transpose.
VectorTransposeLowering vectorTransposeLowering =
VectorTransposeLowering::EltWise;
VectorTransformsOptions &
setVectorTransposeLowering(VectorTransposeLowering opt) {
vectorTransposeLowering = opt;
return *this;
}
/// Option to control the splitting of vector transfers.
VectorTransferSplit vectorTransferSplit = VectorTransferSplit::None;
VectorTransformsOptions &setVectorTransferSplit(VectorTransferSplit opt) {
vectorTransferSplit = opt;
return *this;
}
};
struct VectorTransformsOptions;
/// Options that control the vector unrolling.
struct UnrollVectorOptions {
@@ -109,45 +74,6 @@ struct UnrollVectorOptions {
// Vector transformation exposed as populate functions over rewrite patterns.
//===----------------------------------------------------------------------===//
/// Insert TransposeLowering patterns into extraction/insertion.
void populateVectorTransposeLoweringPatterns(
RewritePatternSet &patterns,
VectorTransformsOptions options = VectorTransformsOptions(),
PatternBenefit benefit = 1);
/// Collect a set of patterns to convert vector.multi_reduction op into
/// a sequence of vector.reduction ops. The patterns comprise:
/// - InnerOuterDimReductionConversion: rewrites vector.multi_reduction such
/// that all reduction dimensions are either innermost or outermost, by adding
/// the proper vector.transpose operations.
/// - ReduceMultiDimReductionRank: once in innermost or outermost reduction
/// form, rewrites n-D vector.multi_reduction into 2-D vector.multi_reduction,
/// by introducing vector.shape_cast ops to collapse + multi-reduce + expand
/// back.
/// - TwoDimMultiReductionToElementWise: once in 2-D vector.multi_reduction
/// form, with an **outermost** reduction dimension, unroll the outer dimension
/// to obtain a sequence of 1-D vector ops. This also has an opportunity for
/// tree-reduction (in the future).
/// - TwoDimMultiReductionToReduction: once in 2-D vector.multi_reduction form,
/// with an **innermost** reduction dimension, unroll the outer dimension to
/// obtain a sequence of extract + vector.reduction + insert. This can further
/// lower to horizontal reduction ops.
/// - OneDimMultiReductionToTwoDim: for cases that reduce to 1-D vector<k>
/// reduction (and are thus missing either a parallel or a reduction), we lift
/// them back up to 2-D with a simple vector.shape_cast to vector<1xk> so that
/// the other patterns can kick in, thus fully exiting out of the
/// vector.multi_reduction abstraction.
void populateVectorMultiReductionLoweringPatterns(
RewritePatternSet &patterns, VectorMultiReductionLowering options,
PatternBenefit benefit = 1);
/// Collects patterns to progressively lower vector contraction ops on high-D
/// into low-D reduction and product ops.
void populateVectorContractLoweringPatterns(
RewritePatternSet &patterns,
VectorTransformsOptions options = VectorTransformsOptions(),
PatternBenefit benefit = 1);
/// Canonicalization of a `vector.contraction %a, %b, %c` with row-major matmul
/// semantics to a contraction with MMT semantics (matrix matrix multiplication
/// with the RHS transposed). This specific form is meant to have the vector
@@ -174,67 +100,43 @@ void populateVectorContractCanonicalizeMatmulToMMT(
void populateVectorReductionToContractPatterns(RewritePatternSet &patterns,
PatternBenefit benefit = 1);
/// Collect patterns to convert scan op
void populateVectorScanLoweringPatterns(RewritePatternSet &patterns,
PatternBenefit benefit = 1);
/// Populate `patterns` with the following patterns.
///
/// - VectorTransferFullPartialRewriter
///
/// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds
/// masking) fast path and a slow path.
///
/// Example (a 2-D vector.transfer_read):
/// ```
/// %1 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...>
/// ```
/// is transformed into:
/// ```
/// %1:3 = scf.if (%inBounds) {
/// // fast path, direct cast
/// memref.cast %A: memref<A...> to compatibleMemRefType
/// scf.yield %view : compatibleMemRefType, index, index
/// } else {
/// // slow path, not in-bounds vector.transfer or linalg.copy.
/// memref.cast %alloc: memref<B...> to compatibleMemRefType
/// scf.yield %4 : compatibleMemRefType, index, index
/// }
/// %0 = vector.transfer_read %1#0[%1#1, %1#2] {in_bounds = [true ... true]}
/// ```
/// where `alloc` is a buffer alloca'ed at the top of the function, sized to
/// hold one vector.
///
/// Preconditions:
/// 1. `xferOp.permutation_map()` must be a minor identity map
/// 2. the rank of the `xferOp.memref()` and the rank of the `xferOp.vector()`
/// must be equal. This will be relaxed in the future but requires
/// rank-reducing subviews.
void populateVectorTransferFullPartialPatterns(
RewritePatternSet &patterns, const VectorTransformsOptions &options);
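The split strategy comes from the `vectorTransferSplit` field of `VectorTransformsOptions`; a hedged sketch (hypothetical helper, assuming the `VectorTransferSplit::VectorTransfer` enumerator) requesting the in-bounds/out-of-bounds split:

static void splitTransfers(mlir::Operation *root) {
  mlir::RewritePatternSet patterns(root->getContext());
  mlir::vector::VectorTransformsOptions options;
  // Keep the slow path as an unmasked vector.transfer on the temporary buffer.
  options.setVectorTransferSplit(
      mlir::vector::VectorTransferSplit::VectorTransfer);
  mlir::vector::populateVectorTransferFullPartialPatterns(patterns, options);
  (void)mlir::applyPatternsAndFoldGreedily(root, std::move(patterns));
}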
//===----------------------------------------------------------------------===//
// Vector.transfer patterns.
//===----------------------------------------------------------------------===//
/// Collect a set of transfer read/write lowering patterns that simplify the
/// permutation map (e.g., converting it to a minor identity map) by inserting
/// broadcasts and transposes. More specifically:
///
/// [TransferReadPermutationLowering]
/// Lower transfer_read op with permutation into a transfer_read with a
/// permutation map composed of leading zeros followed by a minor identity +
/// vector.transpose op.
/// Ex:
/// vector.transfer_read ...
/// permutation_map: (d0, d1, d2) -> (0, d1)
/// into:
/// %v = vector.transfer_read ...
/// permutation_map: (d0, d1, d2) -> (d1, 0)
/// vector.transpose %v, [1, 0]
///
/// vector.transfer_read ...
/// permutation_map: (d0, d1, d2, d3) -> (0, 0, 0, d1, d3)
/// into:
/// %v = vector.transfer_read ...
/// permutation_map: (d0, d1, d2, d3) -> (0, 0, d1, 0, d3)
/// vector.transpose %v, [0, 1, 3, 2, 4]
/// Note that an alternative is to transform it to linalg.transpose +
/// vector.transfer_read to do the transpose in memory instead.
///
/// [TransferWritePermutationLowering]
/// Lower transfer_write op with permutation into a transfer_write with a
/// minor identity permutation map. (transfer_write ops cannot have broadcasts.)
/// Ex:
/// vector.transfer_write %v ...
/// permutation_map: (d0, d1, d2) -> (d2, d0, d1)
/// into:
/// %tmp = vector.transpose %v, [2, 0, 1]
/// vector.transfer_write %tmp ...
/// permutation_map: (d0, d1, d2) -> (d0, d1, d2)
///
/// vector.transfer_write %v ...
/// permutation_map: (d0, d1, d2, d3) -> (d3, d2)
/// into:
/// %tmp = vector.transpose %v, [1, 0]
/// %v = vector.transfer_write %tmp ...
/// permutation_map: (d0, d1, d2, d3) -> (d2, d3)
///
/// [TransferOpReduceRank]
/// Lower transfer_read op with broadcast in the leading dimensions into
/// transfer_read of lower rank + vector.broadcast.
/// Ex: vector.transfer_read ...
/// permutation_map: (d0, d1, d2, d3) -> (0, d1, 0, d3)
/// into:
/// %v = vector.transfer_read ...
/// permutation_map: (d0, d1, d2, d3) -> (d1, 0, d3)
/// vector.broadcast %v
void populateVectorTransferPermutationMapLoweringPatterns(
RewritePatternSet &patterns, PatternBenefit benefit = 1);
/// Collect a set of patterns to reduce the rank of the operands of vector
/// transfer ops to operate on the largest contiguous vector.
@@ -334,220 +236,6 @@ void populateVectorUnrollPatterns(RewritePatternSet &patterns,
const UnrollVectorOptions &options,
PatternBenefit benefit = 1);
/// Expands `vector.gather` ops into a series of conditional scalar loads
/// (`vector.load` for memrefs or `tensor.extract` for tensors). These loads are
/// conditional to avoid out-of-bounds memory accesses and guarded with `scf.if`
/// ops. This lowering path is intended for targets that do not feature
/// dedicated gather ops.
void populateVectorGatherLoweringPatterns(RewritePatternSet &patterns,
PatternBenefit benefit = 1);
//===----------------------------------------------------------------------===//
// Finer-grained patterns exposed for more control over individual lowerings.
//===----------------------------------------------------------------------===//
/// Apply `splitFullAndPartialTransfer` selectively via a pattern. This pattern
/// may take an extra filter to perform selection at a finer granularity.
struct VectorTransferFullPartialRewriter : public RewritePattern {
using FilterConstraintType =
std::function<LogicalResult(VectorTransferOpInterface op)>;
explicit VectorTransferFullPartialRewriter(
MLIRContext *context,
VectorTransformsOptions options = VectorTransformsOptions(),
FilterConstraintType filter =
[](VectorTransferOpInterface op) { return success(); },
PatternBenefit benefit = 1)
: RewritePattern(MatchAnyOpTypeTag(), benefit, context), options(options),
filter(std::move(filter)) {}
/// Performs the rewrite.
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override;
private:
VectorTransformsOptions options;
FilterConstraintType filter;
};
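The filter constraint gives per-op control; a hedged sketch (hypothetical helper) that restricts the full/partial split to rank-2 transfers:

// Hypothetical helper: only split transfers whose vector type is rank-2.
static void addRank2SplitPattern(mlir::RewritePatternSet &patterns) {
  auto onlyRank2 = [](mlir::VectorTransferOpInterface op) {
    return mlir::success(op.getVectorType().getRank() == 2);
  };
  patterns.add<mlir::vector::VectorTransferFullPartialRewriter>(
      patterns.getContext(), mlir::vector::VectorTransformsOptions(),
      onlyRank2);
}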
/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
/// semantics to:
/// ```
/// %flattened_a = vector.shape_cast %a
/// %flattened_b = vector.shape_cast %b
/// %flattened_d = vector.matmul %flattened_a, %flattened_b
/// %d = vector.shape_cast %flattened_d
/// %e = add %c, %d
/// ```
/// `vector.matmul` later lowers to `llvm.matrix.multiply`.
///
/// This only kicks in when VectorTransformsOptions is set to Matmul and
/// the vector.contract op is a row-major matrix multiply.
class ContractionOpToMatmulOpLowering
: public OpRewritePattern<vector::ContractionOp> {
public:
using OpRewritePattern::OpRewritePattern;
using FilterConstraintType =
std::function<LogicalResult(vector::ContractionOp op)>;
static LogicalResult defaultFilter(vector::ContractionOp op) {
return success();
}
ContractionOpToMatmulOpLowering(
vector::VectorTransformsOptions vectorTransformOptions,
MLIRContext *context, PatternBenefit benefit = 1,
FilterConstraintType constraint = defaultFilter)
: OpRewritePattern<vector::ContractionOp>(context, benefit),
vectorTransformOptions(vectorTransformOptions),
filter(std::move(constraint)) {}
LogicalResult matchAndRewrite(vector::ContractionOp op,
PatternRewriter &rewriter) const override;
private:
/// Options to control the vector patterns.
vector::VectorTransformsOptions vectorTransformOptions;
FilterConstraintType filter;
};
/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
/// semantics to a reduction_size-unrolled sequence:
/// ```
/// %at = vector.transpose %a, [1, 0]
/// %bRow0 = vector.extract %b[0]
/// %atRow0 = vector.extract %at[0]
/// %c0 = vector.outerproduct %atRow0, %bRow0, %c
/// ...
/// %bRowK = vector.extract %b[K]
/// %atRowK = vector.extract %at[K]
/// %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
/// ```
///
/// This only kicks in when VectorTransformsOptions is set to OuterProduct and
/// the vector.contract op is a row-major matrix multiply.
class ContractionOpToOuterProductOpLowering
: public OpRewritePattern<vector::ContractionOp> {
public:
using OpRewritePattern::OpRewritePattern;
using FilterConstraintType =
std::function<LogicalResult(vector::ContractionOp op)>;
static LogicalResult defaultFilter(vector::ContractionOp op) {
return success();
}
ContractionOpToOuterProductOpLowering(
vector::VectorTransformsOptions vectorTransformOptions,
MLIRContext *context, PatternBenefit benefit = 1,
FilterConstraintType constraint = defaultFilter)
: OpRewritePattern<vector::ContractionOp>(context, benefit),
vectorTransformOptions(vectorTransformOptions),
filter(std::move(constraint)) {}
LogicalResult matchAndRewrite(vector::ContractionOp op,
PatternRewriter &rewriter) const override;
private:
/// Options to control the vector patterns.
vector::VectorTransformsOptions vectorTransformOptions;
FilterConstraintType filter;
};
/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
/// semantics to an output-size-unrolled sequence:
/// ```
/// %out = arith.constant ... : vector<MxNxelt_type>
/// %bt = vector.transpose %b, [1, 0]
/// %aRow0 = vector.extract %a[0]
/// %btRow0 = vector.extract %bt[0]
/// %c00 = vector.reduce %aRow0, %btRow0
/// %out00 = vector.insert %c00, %out[0, 0]
/// ...
/// %aRowLast = vector.extract %a[M-1]
/// %btRowLast = vector.extract %bt[N-1]
/// %cLastLast = vector.reduce %aRowLast, %btRowLast
/// %outcLastLast = vector.insert %cLastLast, %out[M-1, N-1]
/// ```
///
/// This only kicks in when VectorTransformsOptions is set to Dot and
/// the vector.contract op is a row-major matmul or matvec.
class ContractionOpToDotLowering
: public OpRewritePattern<vector::ContractionOp> {
public:
using OpRewritePattern::OpRewritePattern;
using FilterConstraintType =
std::function<LogicalResult(vector::ContractionOp op)>;
static LogicalResult defaultFilter(vector::ContractionOp op) {
return success();
}
ContractionOpToDotLowering(
vector::VectorTransformsOptions vectorTransformOptions,
MLIRContext *context, PatternBenefit benefit = 1,
const FilterConstraintType &constraint = defaultFilter)
: OpRewritePattern<vector::ContractionOp>(context, benefit),
vectorTransformOptions(vectorTransformOptions), filter(constraint) {}
LogicalResult matchAndRewrite(vector::ContractionOp op,
PatternRewriter &rewriter) const override;
private:
/// Options to control the vector patterns.
vector::VectorTransformsOptions vectorTransformOptions;
FilterConstraintType filter;
};
/// Progressive lowering of ContractionOp.
///
/// One:
/// %x = vector.contract with at least one free/batch dimension
/// is replaced by:
/// %a = vector.contract with one less free/batch dimension
/// %b = vector.contract with one less free/batch dimension
/// ..
/// %x = combine %a %b ..
/// until a pure contraction is reached (no free/batch dimensions),
/// which is replaced by a dot-product.
///
/// This only kicks in when either VectorTransformsOptions is set
/// to Dot or when other contraction patterns fail.
class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
public:
using OpRewritePattern::OpRewritePattern;
using FilterConstraintType =
std::function<LogicalResult(vector::ContractionOp op)>;
static LogicalResult defaultFilter(vector::ContractionOp op) {
return success();
}
ContractionOpLowering(vector::VectorTransformsOptions vectorTransformOptions,
MLIRContext *context, PatternBenefit benefit = 1,
FilterConstraintType constraint = defaultFilter)
: OpRewritePattern<vector::ContractionOp>(context, benefit),
vectorTransformOptions(vectorTransformOptions),
filter(std::move(constraint)) {}
LogicalResult matchAndRewrite(vector::ContractionOp op,
PatternRewriter &rewriter) const override;
private:
/// Options to control the vector patterns.
vector::VectorTransformsOptions vectorTransformOptions;
FilterConstraintType filter;
// Lower one parallel dimension.
FailureOr<Value> lowerParallel(PatternRewriter &rewriter,
vector::ContractionOp op, int64_t lhsIndex,
int64_t rhsIndex, Value mask) const;
// Lower one reduction dimension.
FailureOr<Value> lowerReduction(PatternRewriter &rewriter,
vector::ContractionOp op, Value mask) const;
};
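When the populate function is too coarse, these classes can be added individually; a short sketch matching the usage the transform op below replaces:

static void addContractionLowerings(
    mlir::RewritePatternSet &patterns,
    mlir::vector::VectorTransformsOptions options) {
  // Hand-picked subset; ContractionOpLowering acts as the catch-all fallback.
  patterns.add<mlir::vector::ContractionOpToOuterProductOpLowering,
               mlir::vector::ContractionOpToMatmulOpLowering,
               mlir::vector::ContractionOpLowering>(options,
                                                    patterns.getContext());
}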
} // namespace vector
} // namespace mlir


@@ -24,17 +24,53 @@ class IfOp;
namespace vector {
//===----------------------------------------------------------------------===//
// Vector transformation options exposed as auxiliary structs.
//===----------------------------------------------------------------------===//
/// Structure to control the behavior of vector transform patterns.
struct VectorTransformsOptions {
/// Option to control the lowering of vector.contract.
VectorContractLowering vectorContractLowering = VectorContractLowering::Dot;
VectorTransformsOptions &
setVectorTransformsOptions(VectorContractLowering opt) {
vectorContractLowering = opt;
return *this;
}
/// Option to control the lowering of vector.multi_reduction.
VectorMultiReductionLowering vectorMultiReductionLowering =
VectorMultiReductionLowering::InnerParallel;
VectorTransformsOptions &
setVectorMultiReductionLowering(VectorMultiReductionLowering opt) {
vectorMultiReductionLowering = opt;
return *this;
}
/// Option to control the lowering of vector.transpose.
VectorTransposeLowering vectorTransposeLowering =
VectorTransposeLowering::EltWise;
VectorTransformsOptions &
setVectorTransposeLowering(VectorTransposeLowering opt) {
vectorTransposeLowering = opt;
return *this;
}
/// Option to control the splitting of vector transfers.
VectorTransferSplit vectorTransferSplit = VectorTransferSplit::None;
VectorTransformsOptions &setVectorTransferSplit(VectorTransferSplit opt) {
vectorTransferSplit = opt;
return *this;
}
};
//===----------------------------------------------------------------------===//
// Standalone transformations and helpers.
//===----------------------------------------------------------------------===//
/// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds
/// masking) fastpath and a slowpath.
/// If `ifOp` is not null and the result is `success`, the `ifOp` points to the
/// newly created conditional upon function return.
/// To accommodate the fact that the original vector.transfer indexing may be
/// arbitrary and the slow path indexes @[0...0] in the temporary buffer, the
/// scf.if op returns a view and values of type index.
/// At this time, only vector.transfer_read case is implemented.
/// Split a vector.transfer operation into an in-bounds (i.e., no
/// out-of-bounds masking) fastpath and a slowpath. If `ifOp` is not null and
/// the result is `success`, the `ifOp` points to the newly created conditional
/// upon function return. To accommodate the fact that the original
/// vector.transfer indexing may be arbitrary and the slow path indexes
/// @[0...0] in the temporary buffer, the scf.if op returns a view and values
/// of type index. At this time, only vector.transfer_read case is
/// implemented.
///
/// Example (a 2-D vector.transfer_read):
/// ```
@@ -51,15 +87,16 @@ namespace vector {
/// memref.cast %alloc: memref<B...> to compatibleMemRefType
/// scf.yield %4 : compatibleMemRefType, index, index
/// }
/// %0 = vector.transfer_read %1#0[%1#1, %1#2] {in_bounds = [true ... true]}
/// %0 = vector.transfer_read %1#0[%1#1, %1#2] {in_bounds = [true ...
/// true]}
/// ```
/// where `alloc` is a buffer alloca'ed at the top of the function, sized to
/// hold one vector.
///
/// Preconditions:
/// 1. `xferOp.permutation_map()` must be a minor identity map
/// 2. the rank of the `xferOp.memref()` and the rank of the `xferOp.vector()`
/// must be equal. This will be relaxed in the future but requires
/// rank-reducing subviews.
/// 2. the rank of the `xferOp.memref()` and the rank of the
/// `xferOp.vector()` must be equal. This will be relaxed in the future but
/// requires rank-reducing subviews.
LogicalResult splitFullAndPartialTransfer(
RewriterBase &b, VectorTransferOpInterface xferOp,
VectorTransformsOptions options = VectorTransformsOptions(),


@@ -16,6 +16,7 @@
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"


@@ -19,6 +19,7 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Dialect/X86Vector/Transforms.h"
#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
@@ -64,10 +65,11 @@ void LowerVectorToLLVMPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
populateVectorToVectorCanonicalizationPatterns(patterns);
populateVectorBroadcastLoweringPatterns(patterns);
populateVectorContractLoweringPatterns(patterns);
populateVectorContractLoweringPatterns(patterns, VectorTransformsOptions());
populateVectorMaskOpLoweringPatterns(patterns);
populateVectorShapeCastLoweringPatterns(patterns);
populateVectorTransposeLoweringPatterns(patterns);
populateVectorTransposeLoweringPatterns(patterns,
VectorTransformsOptions());
// Vector transfer ops with rank > 1 should be lowered with VectorToSCF.
populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));


@@ -10,8 +10,8 @@
//
//===----------------------------------------------------------------------===//
#include <type_traits>
#include <optional>
#include <type_traits>
#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
@@ -20,6 +20,7 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"


@@ -20,5 +20,5 @@ add_mlir_dialect_library(MLIRLinalgTransformOps
MLIRSideEffectInterfaces
MLIRTransformDialect
MLIRTransformDialectUtils
MLIRVectorDialect
MLIRVectorTransforms
)


@@ -26,6 +26,7 @@
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
#include "mlir/Dialect/Transform/Utils/Utils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"


@@ -7,13 +7,14 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h"
#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
#include "mlir/Dialect/PDL/IR/PDL.h"
#include "mlir/Dialect/PDL/IR/PDLTypes.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Dialect/X86Vector/Transforms.h"
#include "mlir/Parser/Parser.h"
@@ -82,10 +83,9 @@ DiagnosedSilenceableFailure transform::LowerVectorsOp::apply(
// In the future we may want to more finely select particular stages.
// Stage 1: contraction lowerings.
patterns.add<mlir::vector::ContractionOpToOuterProductOpLowering,
mlir::vector::ContractionOpToMatmulOpLowering,
mlir::vector::ContractionOpLowering>(vectorTransformOptions,
ctx);
populateVectorContractLoweringPatterns(
patterns, vectorTransformOptions, /*benefit=*/1,
/*disableOuterProductLowering=*/true);
vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
// Stage 2: multi-reduction lowerings.
@@ -93,8 +93,7 @@ DiagnosedSilenceableFailure transform::LowerVectorsOp::apply(
patterns, vectorTransformOptions.vectorMultiReductionLowering);
// Stage 3: Rewrite vector.transfer into full and partial parts.
patterns.add<vector::VectorTransferFullPartialRewriter>(
ctx, vectorTransformOptions);
populateVectorTransferFullPartialPatterns(patterns, vectorTransformOptions);
// Stage 4: Lower vector transfers.
vector::populateVectorTransferLoweringPatterns(patterns, maxTransferRank);
@@ -107,8 +106,8 @@ DiagnosedSilenceableFailure transform::LowerVectorsOp::apply(
vector::populateVectorShapeCastLoweringPatterns(patterns);
// Stage 7: Lower vector.transpose.
vector::populateVectorTransposeLoweringPatterns(patterns,
vectorTransformOptions);
vector::populateVectorTransposeLoweringPatterns(
patterns, vectorTransformOptions, /*benefit=*/1);
if (getTransposeAvx2Lowering())
x86vector::avx2::populateSpecializedTransposeLoweringPatterns(
patterns, avx2LoweringOptions, /*benefit=*/10);


@@ -1,14 +1,20 @@
add_mlir_dialect_library(MLIRVectorTransforms
BufferizableOpInterfaceImpl.cpp
Bufferize.cpp
LowerVectorBroadcast.cpp
LowerVectorContract.cpp
LowerVectorGather.cpp
LowerVectorMask.cpp
LowerVectorMultiReduction.cpp
LowerVectorScan.cpp
LowerVectorShapeCast.cpp
LowerVectorTransfer.cpp
LowerVectorTranspose.cpp
VectorDistribute.cpp
VectorDropLeadUnitDim.cpp
VectorInsertExtractStridedSliceRewritePatterns.cpp
VectorMultiDimReductionTransforms.cpp
VectorTransferOpTransforms.cpp
VectorTransferSplitRewritePatterns.cpp
VectorTransferPermutationMapRewritePatterns.cpp
VectorTransforms.cpp
VectorUnroll.cpp


@@ -0,0 +1,156 @@
//===- LowerVectorBroadcast.cpp - Lower 'vector.broadcast' operation ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent rewrites and utilities to lower the
// 'vector.broadcast' operation.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/VectorInterfaces.h"
#include "mlir/Support/LogicalResult.h"
#define DEBUG_TYPE "vector-broadcast-lowering"
using namespace mlir;
using namespace mlir::vector;
namespace {
/// Progressive lowering of BroadcastOp.
class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(vector::BroadcastOp op,
PatternRewriter &rewriter) const override {
auto loc = op.getLoc();
VectorType dstType = op.getResultVectorType();
VectorType srcType = op.getSourceType().dyn_cast<VectorType>();
Type eltType = dstType.getElementType();
// Scalar to any vector can use splat.
if (!srcType) {
rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, op.getSource());
return success();
}
// Determine rank of source and destination.
int64_t srcRank = srcType.getRank();
int64_t dstRank = dstType.getRank();
// Stretching scalar inside vector (e.g. vector<1xf32>) can use splat.
if (srcRank <= 1 && dstRank == 1) {
Value ext;
if (srcRank == 0)
ext = rewriter.create<vector::ExtractElementOp>(loc, op.getSource());
else
ext = rewriter.create<vector::ExtractOp>(loc, op.getSource(), 0);
rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, ext);
return success();
}
// Duplicate this rank.
// For example:
// %x = broadcast %y : k-D to n-D, k < n
// becomes:
// %b = broadcast %y : k-D to (n-1)-D
// %x = [%b,%b,%b,%b] : n-D
// becomes:
// %b = [%y,%y] : (n-1)-D
// %x = [%b,%b,%b,%b] : n-D
if (srcRank < dstRank) {
// Duplication.
VectorType resType =
VectorType::get(dstType.getShape().drop_front(), eltType);
Value bcst =
rewriter.create<vector::BroadcastOp>(loc, resType, op.getSource());
Value result = rewriter.create<arith::ConstantOp>(
loc, dstType, rewriter.getZeroAttr(dstType));
for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d)
result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
rewriter.replaceOp(op, result);
return success();
}
// Find non-matching dimension, if any.
assert(srcRank == dstRank);
int64_t m = -1;
for (int64_t r = 0; r < dstRank; r++)
if (srcType.getDimSize(r) != dstType.getDimSize(r)) {
m = r;
break;
}
// All trailing dimensions are the same. Simply pass through.
if (m == -1) {
rewriter.replaceOp(op, op.getSource());
return success();
}
// Any non-matching dimension forces a stretch along this rank.
// For example:
// %x = broadcast %y : vector<4x1x2xf32> to vector<4x2x2xf32>
// becomes:
// %a = broadcast %y[0] : vector<1x2xf32> to vector<2x2xf32>
// %b = broadcast %y[1] : vector<1x2xf32> to vector<2x2xf32>
// %c = broadcast %y[2] : vector<1x2xf32> to vector<2x2xf32>
// %d = broadcast %y[3] : vector<1x2xf32> to vector<2x2xf32>
// %x = [%a,%b,%c,%d]
// becomes:
// %u = broadcast %y[0][0] : vector<2xf32> to vector <2x2xf32>
// %v = broadcast %y[1][0] : vector<2xf32> to vector <2x2xf32>
// %a = [%u, %v]
// ..
// %x = [%a,%b,%c,%d]
VectorType resType =
VectorType::get(dstType.getShape().drop_front(), eltType);
Value result = rewriter.create<arith::ConstantOp>(
loc, dstType, rewriter.getZeroAttr(dstType));
if (m == 0) {
// Stretch at start.
Value ext = rewriter.create<vector::ExtractOp>(loc, op.getSource(), 0);
Value bcst = rewriter.create<vector::BroadcastOp>(loc, resType, ext);
for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d)
result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
} else {
// Stretch not at start.
for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d) {
Value ext = rewriter.create<vector::ExtractOp>(loc, op.getSource(), d);
Value bcst = rewriter.create<vector::BroadcastOp>(loc, resType, ext);
result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
}
}
rewriter.replaceOp(op, result);
return success();
}
};
} // namespace
void mlir::vector::populateVectorBroadcastLoweringPatterns(
RewritePatternSet &patterns, PatternBenefit benefit) {
patterns.add<BroadcastOpLowering>(patterns.getContext(), benefit);
}
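A hedged sketch of driving just this file's pattern, e.g. from a test pass (the helper is hypothetical):

static void runBroadcastLowering(mlir::Operation *root) {
  mlir::RewritePatternSet patterns(root->getContext());
  mlir::vector::populateVectorBroadcastLoweringPatterns(patterns);
  // Applied exhaustively, rank-n broadcasts peel down to rank-1 splats.
  (void)mlir::applyPatternsAndFoldGreedily(root, std::move(patterns));
}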

File diff suppressed because it is too large


@@ -0,0 +1,173 @@
//===- LowerVectorGather.cpp - Lower 'vector.gather' operation -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent rewrites and utilities to lower the
// 'vector.gather' operation.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/VectorInterfaces.h"
#include "mlir/Support/LogicalResult.h"
#define DEBUG_TYPE "vector-gather-lowering"
using namespace mlir;
using namespace mlir::vector;
namespace {
/// Flattens 2 or more dimensional `vector.gather` ops by unrolling the
/// outermost dimension. For example:
/// ```
/// %g = vector.gather %base[%c0][%v], %mask, %pass_thru :
/// ... into vector<2x3xf32>
///
/// ==>
///
/// %0 = arith.constant dense<0.0> : vector<2x3xf32>
/// %g0 = vector.gather %base[%c0][%v0], %mask0, %pass_thru0 : ...
/// %1 = vector.insert %g0, %0 [0] : vector<3xf32> into vector<2x3xf32>
/// %g1 = vector.gather %base[%c0][%v1], %mask1, %pass_thru1 : ...
/// %g = vector.insert %g1, %1 [1] : vector<3xf32> into vector<2x3xf32>
/// ```
///
/// When applied exhaustively, this will produce a sequence of 1-d gather ops.
struct FlattenGather : OpRewritePattern<vector::GatherOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(vector::GatherOp op,
PatternRewriter &rewriter) const override {
VectorType resultTy = op.getType();
if (resultTy.getRank() < 2)
return rewriter.notifyMatchFailure(op, "already flat");
Location loc = op.getLoc();
Value indexVec = op.getIndexVec();
Value maskVec = op.getMask();
Value passThruVec = op.getPassThru();
Value result = rewriter.create<arith::ConstantOp>(
loc, resultTy, rewriter.getZeroAttr(resultTy));
Type subTy = VectorType::get(resultTy.getShape().drop_front(),
resultTy.getElementType());
for (int64_t i = 0, e = resultTy.getShape().front(); i < e; ++i) {
int64_t thisIdx[1] = {i};
Value indexSubVec =
rewriter.create<vector::ExtractOp>(loc, indexVec, thisIdx);
Value maskSubVec =
rewriter.create<vector::ExtractOp>(loc, maskVec, thisIdx);
Value passThruSubVec =
rewriter.create<vector::ExtractOp>(loc, passThruVec, thisIdx);
Value subGather = rewriter.create<vector::GatherOp>(
loc, subTy, op.getBase(), op.getIndices(), indexSubVec, maskSubVec,
passThruSubVec);
result =
rewriter.create<vector::InsertOp>(loc, subGather, result, thisIdx);
}
rewriter.replaceOp(op, result);
return success();
}
};
/// Turns 1-d `vector.gather` into a scalarized sequence of `vector.load`s or
/// `tensor.extract`s. To avoid out-of-bounds memory accesses, these
/// loads/extracts are made conditional using `scf.if` ops.
struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(vector::GatherOp op,
PatternRewriter &rewriter) const override {
VectorType resultTy = op.getType();
if (resultTy.getRank() != 1)
return rewriter.notifyMatchFailure(op, "unsupported rank");
Location loc = op.getLoc();
Type elemTy = resultTy.getElementType();
// Vector type with a single element. Used to generate `vector.loads`.
VectorType elemVecTy = VectorType::get({1}, elemTy);
Value condMask = op.getMask();
Value base = op.getBase();
Value indexVec = rewriter.createOrFold<arith::IndexCastOp>(
loc, op.getIndexVectorType().clone(rewriter.getIndexType()),
op.getIndexVec());
auto baseOffsets = llvm::to_vector(op.getIndices());
Value lastBaseOffset = baseOffsets.back();
Value result = op.getPassThru();
// Emit a conditional access for each vector element.
for (int64_t i = 0, e = resultTy.getNumElements(); i < e; ++i) {
int64_t thisIdx[1] = {i};
Value condition =
rewriter.create<vector::ExtractOp>(loc, condMask, thisIdx);
Value index = rewriter.create<vector::ExtractOp>(loc, indexVec, thisIdx);
baseOffsets.back() =
rewriter.createOrFold<arith::AddIOp>(loc, lastBaseOffset, index);
auto loadBuilder = [&](OpBuilder &b, Location loc) {
Value extracted;
if (isa<MemRefType>(base.getType())) {
// `vector.load` does not support scalar result; emit a vector load
// and extract the single result instead.
Value load =
b.create<vector::LoadOp>(loc, elemVecTy, base, baseOffsets);
int64_t zeroIdx[1] = {0};
extracted = b.create<vector::ExtractOp>(loc, load, zeroIdx);
} else {
extracted = b.create<tensor::ExtractOp>(loc, base, baseOffsets);
}
Value newResult =
b.create<vector::InsertOp>(loc, extracted, result, thisIdx);
b.create<scf::YieldOp>(loc, newResult);
};
auto passThruBuilder = [result](OpBuilder &b, Location loc) {
b.create<scf::YieldOp>(loc, result);
};
result =
rewriter
.create<scf::IfOp>(loc, condition, /*thenBuilder=*/loadBuilder,
/*elseBuilder=*/passThruBuilder)
.getResult(0);
}
rewriter.replaceOp(op, result);
return success();
}
};
} // namespace
void mlir::vector::populateVectorGatherLoweringPatterns(
RewritePatternSet &patterns, PatternBenefit benefit) {
patterns.add<FlattenGather, Gather1DToConditionalLoads>(patterns.getContext(),
benefit);
}


@@ -6,7 +6,7 @@
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent rewrites and utilitites to lower the
// This file implements target-independent rewrites and utilities to lower the
// 'vector.mask' operation.
//
//===----------------------------------------------------------------------===//
@@ -14,6 +14,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/Passes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -30,6 +31,147 @@ namespace vector {
using namespace mlir;
using namespace mlir::vector;
//===----------------------------------------------------------------------===//
// populateVectorMaskOpLoweringPatterns
//===----------------------------------------------------------------------===//
namespace {
/// Progressive lowering of CreateMaskOp.
/// One:
/// %x = vector.create_mask %a, ... : vector<dx...>
/// is replaced by:
/// %l = vector.create_mask ... : vector<...> ; one lower rank
/// %0 = arith.cmpi "slt", %ci, %a |
/// %1 = select %0, %l, %zeroes |
/// %r = vector.insert %1, %pr [i] | d-times
/// %x = ....
/// until a one-dimensional vector is reached.
class CreateMaskOpLowering : public OpRewritePattern<vector::CreateMaskOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(vector::CreateMaskOp op,
PatternRewriter &rewriter) const override {
auto dstType = op.getResult().getType().cast<VectorType>();
int64_t rank = dstType.getRank();
if (rank <= 1)
return rewriter.notifyMatchFailure(
op, "0-D and 1-D vectors are handled separately");
auto loc = op.getLoc();
auto eltType = dstType.getElementType();
int64_t dim = dstType.getDimSize(0);
Value idx = op.getOperand(0);
VectorType lowType =
VectorType::get(dstType.getShape().drop_front(), eltType);
Value trueVal = rewriter.create<vector::CreateMaskOp>(
loc, lowType, op.getOperands().drop_front());
Value falseVal = rewriter.create<arith::ConstantOp>(
loc, lowType, rewriter.getZeroAttr(lowType));
Value result = rewriter.create<arith::ConstantOp>(
loc, dstType, rewriter.getZeroAttr(dstType));
for (int64_t d = 0; d < dim; d++) {
Value bnd =
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(d));
Value val = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
bnd, idx);
Value sel = rewriter.create<arith::SelectOp>(loc, val, trueVal, falseVal);
auto pos = rewriter.getI64ArrayAttr(d);
result =
rewriter.create<vector::InsertOp>(loc, dstType, sel, result, pos);
}
rewriter.replaceOp(op, result);
return success();
}
};
/// Progressive lowering of ConstantMaskOp.
/// One:
/// %x = vector.constant_mask [a,b]
/// is replaced by:
/// %z = zero-result
/// %l = vector.constant_mask [b]
/// %4 = vector.insert %l, %z[0]
/// ..
/// %x = vector.insert %l, %..[a-1]
/// until a one-dimensional vector is reached. All these operations
/// will be folded at LLVM IR level.
class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(vector::ConstantMaskOp op,
PatternRewriter &rewriter) const override {
auto loc = op.getLoc();
auto dstType = op.getType();
auto eltType = dstType.getElementType();
auto dimSizes = op.getMaskDimSizes();
int64_t rank = dstType.getRank();
if (rank == 0) {
assert(dimSizes.size() == 1 &&
"Expected exactly one dim size for a 0-D vector");
bool value = dimSizes[0].cast<IntegerAttr>().getInt() == 1;
rewriter.replaceOpWithNewOp<arith::ConstantOp>(
op, dstType,
DenseIntElementsAttr::get(
VectorType::get(ArrayRef<int64_t>{}, rewriter.getI1Type()),
ArrayRef<bool>{value}));
return success();
}
// Scalable constant masks can only be lowered for the "none set" case.
if (dstType.cast<VectorType>().isScalable()) {
rewriter.replaceOpWithNewOp<arith::ConstantOp>(
op, DenseElementsAttr::get(dstType, false));
return success();
}
int64_t trueDim = std::min(dstType.getDimSize(0),
dimSizes[0].cast<IntegerAttr>().getInt());
if (rank == 1) {
// Express constant 1-D case in explicit vector form:
// [T,..,T,F,..,F].
SmallVector<bool> values(dstType.getDimSize(0));
for (int64_t d = 0; d < trueDim; d++)
values[d] = true;
rewriter.replaceOpWithNewOp<arith::ConstantOp>(
op, dstType, rewriter.getBoolVectorAttr(values));
return success();
}
VectorType lowType =
VectorType::get(dstType.getShape().drop_front(), eltType);
SmallVector<int64_t> newDimSizes;
for (int64_t r = 1; r < rank; r++)
newDimSizes.push_back(dimSizes[r].cast<IntegerAttr>().getInt());
Value trueVal = rewriter.create<vector::ConstantMaskOp>(
loc, lowType, rewriter.getI64ArrayAttr(newDimSizes));
Value result = rewriter.create<arith::ConstantOp>(
loc, dstType, rewriter.getZeroAttr(dstType));
for (int64_t d = 0; d < trueDim; d++) {
auto pos = rewriter.getI64ArrayAttr(d);
result =
rewriter.create<vector::InsertOp>(loc, dstType, trueVal, result, pos);
}
rewriter.replaceOp(op, result);
return success();
}
};
} // namespace
void mlir::vector::populateVectorMaskOpLoweringPatterns(
RewritePatternSet &patterns, PatternBenefit benefit) {
patterns.add<CreateMaskOpLowering, ConstantMaskOpLowering>(
patterns.getContext(), benefit);
}
//===----------------------------------------------------------------------===//
// populateVectorMaskLoweringPatternsForSideEffectingOps
//===----------------------------------------------------------------------===//
namespace {
/// The `MaskOpRewritePattern` implements a pattern that follows a two-fold


@@ -1,4 +1,4 @@
//===- VectorMultiDimReductionTransforms.cpp - Multi-Reduction Transforms -===//
//===- LowerVectorMultiReduction.cpp - Lower `vector.multi_reduction` op --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
// Exceptions. See https://llvm.org/LICENSE.txt for license information.
@@ -6,12 +6,13 @@
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent rewrites of MultiDimReductionOp.
// This file implements target-independent rewrites and utilities to lower the
// 'vector.multi_reduction' operation.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/TypeUtilities.h"
@@ -19,6 +20,7 @@
using namespace mlir;
namespace {
/// This file implements the following transformations as composable atomic
/// patterns.
@@ -441,6 +443,7 @@ struct OneDimMultiReductionToTwoDim
return success();
}
};
} // namespace
void mlir::vector::populateVectorMultiReductionLoweringPatterns(
RewritePatternSet &patterns, VectorMultiReductionLowering options,


@@ -0,0 +1,251 @@
//===- LowerVectorScan.cpp - Lower 'vector.scan' operation ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent rewrites and utilities to lower the
// 'vector.scan' operation.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/VectorInterfaces.h"
#include "mlir/Support/LogicalResult.h"
#define DEBUG_TYPE "vector-scan-lowering"
using namespace mlir;
using namespace mlir::vector;
/// This function constructs the appropriate integer or float
/// operation given the vector combining kind and operands. The
/// supported int operations are : add, mul, min (signed/unsigned),
/// max(signed/unsigned), and, or, xor. The supported float
/// operations are : add, mul, min and max.
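/// For example, genOperator(loc, x, y, CombiningKind::ADD, rewriter) emits
/// arith.addi when the element type is an integer/index type and arith.addf
/// otherwise.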
static Value genOperator(Location loc, Value x, Value y,
vector::CombiningKind kind,
PatternRewriter &rewriter) {
using vector::CombiningKind;
auto elType = x.getType().cast<VectorType>().getElementType();
bool isInt = elType.isIntOrIndex();
Value combinedResult{nullptr};
switch (kind) {
case CombiningKind::ADD:
if (isInt)
combinedResult = rewriter.create<arith::AddIOp>(loc, x, y);
else
combinedResult = rewriter.create<arith::AddFOp>(loc, x, y);
break;
case CombiningKind::MUL:
if (isInt)
combinedResult = rewriter.create<arith::MulIOp>(loc, x, y);
else
combinedResult = rewriter.create<arith::MulFOp>(loc, x, y);
break;
case CombiningKind::MINUI:
combinedResult = rewriter.create<arith::MinUIOp>(loc, x, y);
break;
case CombiningKind::MINSI:
combinedResult = rewriter.create<arith::MinSIOp>(loc, x, y);
break;
case CombiningKind::MAXUI:
combinedResult = rewriter.create<arith::MaxUIOp>(loc, x, y);
break;
case CombiningKind::MAXSI:
combinedResult = rewriter.create<arith::MaxSIOp>(loc, x, y);
break;
case CombiningKind::AND:
combinedResult = rewriter.create<arith::AndIOp>(loc, x, y);
break;
case CombiningKind::OR:
combinedResult = rewriter.create<arith::OrIOp>(loc, x, y);
break;
case CombiningKind::XOR:
combinedResult = rewriter.create<arith::XOrIOp>(loc, x, y);
break;
case CombiningKind::MINF:
combinedResult = rewriter.create<arith::MinFOp>(loc, x, y);
break;
case CombiningKind::MAXF:
combinedResult = rewriter.create<arith::MaxFOp>(loc, x, y);
break;
}
return combinedResult;
}
/// This function checks to see if the vector combining kind
/// is consistent with the integer or float element type.
static bool isValidKind(bool isInt, vector::CombiningKind kind) {
using vector::CombiningKind;
enum class KindType { FLOAT, INT, INVALID };
KindType type{KindType::INVALID};
switch (kind) {
case CombiningKind::MINF:
case CombiningKind::MAXF:
type = KindType::FLOAT;
break;
case CombiningKind::MINUI:
case CombiningKind::MINSI:
case CombiningKind::MAXUI:
case CombiningKind::MAXSI:
case CombiningKind::AND:
case CombiningKind::OR:
case CombiningKind::XOR:
type = KindType::INT;
break;
case CombiningKind::ADD:
case CombiningKind::MUL:
type = isInt ? KindType::INT : KindType::FLOAT;
break;
}
bool isValidIntKind = (type == KindType::INT) && isInt;
bool isValidFloatKind = (type == KindType::FLOAT) && (!isInt);
return (isValidIntKind || isValidFloatKind);
}
namespace {
/// Convert vector.scan op into arith ops and vector.insert_strided_slice /
/// vector.extract_strided_slice.
///
/// Example:
///
/// ```
/// %0:2 = vector.scan <add>, %arg0, %arg1
/// {inclusive = true, reduction_dim = 1} :
/// (vector<2x3xi32>, vector<2xi32>) to (vector<2x3xi32>, vector<2xi32>)
/// ```
///
/// is converted to:
///
/// ```
/// %cst = arith.constant dense<0> : vector<2x3xi32>
/// %0 = vector.extract_strided_slice %arg0
/// {offsets = [0, 0], sizes = [2, 1], strides = [1, 1]}
/// : vector<2x3xi32> to vector<2x1xi32>
/// %1 = vector.insert_strided_slice %0, %cst
/// {offsets = [0, 0], strides = [1, 1]}
/// : vector<2x1xi32> into vector<2x3xi32>
/// %2 = vector.extract_strided_slice %arg0
/// {offsets = [0, 1], sizes = [2, 1], strides = [1, 1]}
/// : vector<2x3xi32> to vector<2x1xi32>
/// %3 = arith.addi %0, %2 : vector<2x1xi32>
/// %4 = vector.insert_strided_slice %3, %1
/// {offsets = [0, 1], strides = [1, 1]}
/// : vector<2x1xi32> into vector<2x3xi32>
/// %5 = vector.extract_strided_slice %arg0
/// {offsets = [0, 2], sizes = [2, 1], strides = [1, 1]}
/// : vector<2x3xi32> to vector<2x1xi32>
/// %6 = arith.addi %3, %5 : vector<2x1xi32>
/// %7 = vector.insert_strided_slice %6, %4
/// {offsets = [0, 2], strides = [1, 1]}
/// : vector<2x1xi32> into vector<2x3xi32>
/// %8 = vector.shape_cast %6 : vector<2x1xi32> to vector<2xi32>
/// return %7, %8 : vector<2x3xi32>, vector<2xi32>
/// ```
struct ScanToArithOps : public OpRewritePattern<vector::ScanOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(vector::ScanOp scanOp,
PatternRewriter &rewriter) const override {
auto loc = scanOp.getLoc();
VectorType destType = scanOp.getDestType();
ArrayRef<int64_t> destShape = destType.getShape();
auto elType = destType.getElementType();
bool isInt = elType.isIntOrIndex();
if (!isValidKind(isInt, scanOp.getKind()))
return failure();
VectorType resType = VectorType::get(destShape, elType);
Value result = rewriter.create<arith::ConstantOp>(
loc, resType, rewriter.getZeroAttr(resType));
int64_t reductionDim = scanOp.getReductionDim();
bool inclusive = scanOp.getInclusive();
int64_t destRank = destType.getRank();
VectorType initialValueType = scanOp.getInitialValueType();
int64_t initialValueRank = initialValueType.getRank();
SmallVector<int64_t> reductionShape(destShape.begin(), destShape.end());
reductionShape[reductionDim] = 1;
VectorType reductionType = VectorType::get(reductionShape, elType);
SmallVector<int64_t> offsets(destRank, 0);
SmallVector<int64_t> strides(destRank, 1);
SmallVector<int64_t> sizes(destShape.begin(), destShape.end());
sizes[reductionDim] = 1;
ArrayAttr scanSizes = rewriter.getI64ArrayAttr(sizes);
ArrayAttr scanStrides = rewriter.getI64ArrayAttr(strides);
Value lastOutput, lastInput;
for (int i = 0; i < destShape[reductionDim]; i++) {
offsets[reductionDim] = i;
ArrayAttr scanOffsets = rewriter.getI64ArrayAttr(offsets);
Value input = rewriter.create<vector::ExtractStridedSliceOp>(
loc, reductionType, scanOp.getSource(), scanOffsets, scanSizes,
scanStrides);
Value output;
if (i == 0) {
if (inclusive) {
output = input;
} else {
if (initialValueRank == 0) {
// ShapeCastOp cannot handle 0-D vectors
output = rewriter.create<vector::BroadcastOp>(
loc, input.getType(), scanOp.getInitialValue());
} else {
output = rewriter.create<vector::ShapeCastOp>(
loc, input.getType(), scanOp.getInitialValue());
}
}
} else {
Value y = inclusive ? input : lastInput;
output = genOperator(loc, lastOutput, y, scanOp.getKind(), rewriter);
assert(output != nullptr);
}
result = rewriter.create<vector::InsertStridedSliceOp>(
loc, output, result, offsets, strides);
lastOutput = output;
lastInput = input;
}
Value reduction;
if (initialValueRank == 0) {
Value v = rewriter.create<vector::ExtractOp>(loc, lastOutput, 0);
reduction =
rewriter.create<vector::BroadcastOp>(loc, initialValueType, v);
} else {
reduction = rewriter.create<vector::ShapeCastOp>(loc, initialValueType,
lastOutput);
}
rewriter.replaceOp(scanOp, {result, reduction});
return success();
}
};
} // namespace
void mlir::vector::populateVectorScanLoweringPatterns(
RewritePatternSet &patterns, PatternBenefit benefit) {
patterns.add<ScanToArithOps>(patterns.getContext(), benefit);
}


@@ -0,0 +1,177 @@
//===- LowerVectorShapeCast.cpp - Lower 'vector.shape_cast' operation -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent rewrites and utilities to lower the
// 'vector.shape_cast' operation.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/VectorInterfaces.h"
#include "mlir/Support/LogicalResult.h"
#define DEBUG_TYPE "vector-shape-cast-lowering"
using namespace mlir;
using namespace mlir::vector;
namespace {
/// ShapeCastOp 2D -> 1D downcast serves the purpose of flattening 2-D to 1-D
/// vectors progressively on the way to target llvm.matrix intrinsics.
/// This iterates over the most major dimension of the 2-D vector and performs
/// rewrites into:
/// vector.extract from 2-D + vector.insert_strided_slice offset into 1-D
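///
/// As an illustrative sketch (the shapes are hypothetical):
///   %0 = vector.shape_cast %a : vector<2x4xf32> to vector<8xf32>
/// becomes roughly:
///   %z = arith.constant dense<0.0> : vector<8xf32>
///   %e0 = vector.extract %a[0] : vector<2x4xf32>
///   %d0 = vector.insert_strided_slice %e0, %z
///         {offsets = [0], strides = [1]} : vector<4xf32> into vector<8xf32>
///   %e1 = vector.extract %a[1] : vector<2x4xf32>
///   %0 = vector.insert_strided_slice %e1, %d0
///        {offsets = [4], strides = [1]} : vector<4xf32> into vector<8xf32>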
class ShapeCastOp2DDownCastRewritePattern
: public OpRewritePattern<vector::ShapeCastOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(vector::ShapeCastOp op,
PatternRewriter &rewriter) const override {
auto sourceVectorType = op.getSourceVectorType();
auto resultVectorType = op.getResultVectorType();
if (sourceVectorType.getRank() != 2 || resultVectorType.getRank() != 1)
return failure();
auto loc = op.getLoc();
Value desc = rewriter.create<arith::ConstantOp>(
loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
unsigned mostMinorVectorSize = sourceVectorType.getShape()[1];
for (int64_t i = 0, e = sourceVectorType.getShape().front(); i != e; ++i) {
Value vec = rewriter.create<vector::ExtractOp>(loc, op.getSource(), i);
desc = rewriter.create<vector::InsertStridedSliceOp>(
loc, vec, desc,
/*offsets=*/i * mostMinorVectorSize, /*strides=*/1);
}
rewriter.replaceOp(op, desc);
return success();
}
};
/// ShapeCastOp 1D -> 2D upcast serves the purpose of unflattening 2-D from 1-D
/// vectors progressively.
/// This iterates over the most major dimension of the 2-D vector and performs
/// rewrites into:
/// vector.extract_strided_slice from 1-D + vector.insert into 2-D
/// Note that 1-D extract_strided_slice ops are lowered to efficient
/// vector.shuffle ops.
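///
/// As an illustrative sketch (the shapes are hypothetical):
///   %0 = vector.shape_cast %a : vector<8xf32> to vector<2x4xf32>
/// becomes roughly:
///   %z = arith.constant dense<0.0> : vector<2x4xf32>
///   %s0 = vector.extract_strided_slice %a {offsets = [0], sizes = [4],
///         strides = [1]} : vector<8xf32> to vector<4xf32>
///   %d0 = vector.insert %s0, %z [0] : vector<4xf32> into vector<2x4xf32>
///   %s1 = vector.extract_strided_slice %a {offsets = [4], sizes = [4],
///         strides = [1]} : vector<8xf32> to vector<4xf32>
///   %0 = vector.insert %s1, %d0 [1] : vector<4xf32> into vector<2x4xf32>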
class ShapeCastOp2DUpCastRewritePattern
: public OpRewritePattern<vector::ShapeCastOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(vector::ShapeCastOp op,
PatternRewriter &rewriter) const override {
auto sourceVectorType = op.getSourceVectorType();
auto resultVectorType = op.getResultVectorType();
if (sourceVectorType.getRank() != 1 || resultVectorType.getRank() != 2)
return failure();
auto loc = op.getLoc();
Value desc = rewriter.create<arith::ConstantOp>(
loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
unsigned mostMinorVectorSize = resultVectorType.getShape()[1];
for (int64_t i = 0, e = resultVectorType.getShape().front(); i != e; ++i) {
Value vec = rewriter.create<vector::ExtractStridedSliceOp>(
loc, op.getSource(), /*offsets=*/i * mostMinorVectorSize,
/*sizes=*/mostMinorVectorSize,
/*strides=*/1);
desc = rewriter.create<vector::InsertOp>(loc, vec, desc, i);
}
rewriter.replaceOp(op, desc);
return success();
}
};
// We typically should not lower general shape cast operations into data
// movement instructions, since the assumption is that these casts are
// optimized away during progressive lowering. For completeness, however,
// we fall back to a reference implementation that moves all elements
// into the right place if we get here.
class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(vector::ShapeCastOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
auto sourceVectorType = op.getSourceVectorType();
auto resultVectorType = op.getResultVectorType();
// Special case 2D / 1D lowerings with better implementations.
// TODO: make this ND / 1D to allow generic ND -> 1D -> MD.
int64_t srcRank = sourceVectorType.getRank();
int64_t resRank = resultVectorType.getRank();
if ((srcRank == 2 && resRank == 1) || (srcRank == 1 && resRank == 2))
return failure();
// Generic ShapeCast lowering path goes all the way down to unrolled scalar
// extract/insert chains.
// TODO: consider evolving the semantics to only allow 1D source or dest and
// drop this potentially very expensive lowering.
// Compute number of elements involved in the reshape.
int64_t numElts = 1;
for (int64_t r = 0; r < srcRank; r++)
numElts *= sourceVectorType.getDimSize(r);
// Replace with data movement operations:
// x[0,0,0] = y[0,0]
// x[0,0,1] = y[0,1]
// x[0,1,0] = y[0,2]
// etc., incrementing the two index vectors "row-major"
// within the source and result shape.
SmallVector<int64_t> srcIdx(srcRank);
SmallVector<int64_t> resIdx(resRank);
Value result = rewriter.create<arith::ConstantOp>(
loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
for (int64_t i = 0; i < numElts; i++) {
if (i != 0) {
incIdx(srcIdx, sourceVectorType, srcRank - 1);
incIdx(resIdx, resultVectorType, resRank - 1);
}
Value e = rewriter.create<vector::ExtractOp>(loc, op.getSource(), srcIdx);
result = rewriter.create<vector::InsertOp>(loc, e, result, resIdx);
}
rewriter.replaceOp(op, result);
return success();
}
private:
static void incIdx(SmallVector<int64_t> &idx, VectorType tp, int64_t r) {
assert(0 <= r && r < tp.getRank());
if (++idx[r] == tp.getDimSize(r)) {
idx[r] = 0;
incIdx(idx, tp, r - 1);
}
}
};
} // namespace
void mlir::vector::populateVectorShapeCastLoweringPatterns(
RewritePatternSet &patterns, PatternBenefit benefit) {
patterns.add<ShapeCastOp2DDownCastRewritePattern,
ShapeCastOp2DUpCastRewritePattern, ShapeCastOpRewritePattern>(
patterns.getContext(), benefit);
}


@@ -14,7 +14,7 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Interfaces/VectorInterfaces.h"
using namespace mlir;
@@ -46,6 +46,11 @@ static Value extendVectorRank(OpBuilder &builder, Location loc, Value vec,
return builder.create<vector::BroadcastOp>(loc, newVecType, vec);
}
//===----------------------------------------------------------------------===//
// populateVectorTransferPermutationMapLoweringPatterns
//===----------------------------------------------------------------------===//
namespace {
/// Lower transfer_read op with permutation into a transfer_read with a
/// permutation map composed of leading zeros followed by a minor identity +
/// vector.transpose op.
@@ -332,6 +337,8 @@ struct TransferOpReduceRank : public OpRewritePattern<vector::TransferReadOp> {
}
};
} // namespace
void mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(
RewritePatternSet &patterns, PatternBenefit benefit) {
patterns
@@ -339,3 +346,239 @@ void mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(
TransferOpReduceRank, TransferWriteNonPermutationLowering>(
patterns.getContext(), benefit);
}
//===----------------------------------------------------------------------===//
// populateVectorTransferLoweringPatterns
//===----------------------------------------------------------------------===//
namespace {
/// Progressive lowering of transfer_read. This pattern supports lowering of
/// `vector.transfer_read` to a combination of `vector.load` and
/// `vector.broadcast` if all of the following hold:
/// - Stride of most minor memref dimension must be 1.
/// - Out-of-bounds masking is not required.
/// - If the memref's element type is a vector type then it coincides with the
/// result type.
/// - The permutation map doesn't perform permutation (broadcasting is allowed).
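///
/// As an illustrative sketch (shapes and values are hypothetical):
///   %0 = vector.transfer_read %mem[%i], %pad {in_bounds = [true]}
///       : memref<?xf32>, vector<4xf32>
/// lowers to:
///   %0 = vector.load %mem[%i] : memref<?xf32>, vector<4xf32>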
struct TransferReadToVectorLoadLowering
: public OpRewritePattern<vector::TransferReadOp> {
TransferReadToVectorLoadLowering(MLIRContext *context,
std::optional<unsigned> maxRank,
PatternBenefit benefit = 1)
: OpRewritePattern<vector::TransferReadOp>(context, benefit),
maxTransferRank(maxRank) {}
LogicalResult matchAndRewrite(vector::TransferReadOp read,
PatternRewriter &rewriter) const override {
if (maxTransferRank && read.getVectorType().getRank() > *maxTransferRank)
return failure();
SmallVector<unsigned> broadcastedDims;
// Permutations are handled by VectorToSCF or
// populateVectorTransferPermutationMapLoweringPatterns.
// We let the 0-d corner case pass-through as it is supported.
if (!read.getPermutationMap().isMinorIdentityWithBroadcasting(
&broadcastedDims))
return failure();
auto memRefType = read.getShapedType().dyn_cast<MemRefType>();
if (!memRefType)
return failure();
// Non-unit strides are handled by VectorToSCF.
if (!vector::isLastMemrefDimUnitStride(memRefType))
return failure();
// If there is broadcasting involved then we first load the unbroadcasted
// vector, and then broadcast it with `vector.broadcast`.
ArrayRef<int64_t> vectorShape = read.getVectorType().getShape();
SmallVector<int64_t> unbroadcastedVectorShape(vectorShape.begin(),
vectorShape.end());
for (unsigned i : broadcastedDims)
unbroadcastedVectorShape[i] = 1;
VectorType unbroadcastedVectorType = VectorType::get(
unbroadcastedVectorShape, read.getVectorType().getElementType());
// `vector.load` supports vector types as memref's elements only when the
// resulting vector type is the same as the element type.
auto memrefElTy = memRefType.getElementType();
if (memrefElTy.isa<VectorType>() && memrefElTy != unbroadcastedVectorType)
return failure();
// Otherwise, element types of the memref and the vector must match.
if (!memrefElTy.isa<VectorType>() &&
memrefElTy != read.getVectorType().getElementType())
return failure();
// Out-of-bounds dims are handled by MaterializeTransferMask.
if (read.hasOutOfBoundsDim())
return failure();
// Create vector load op.
Operation *loadOp;
if (read.getMask()) {
Value fill = rewriter.create<vector::SplatOp>(
read.getLoc(), unbroadcastedVectorType, read.getPadding());
loadOp = rewriter.create<vector::MaskedLoadOp>(
read.getLoc(), unbroadcastedVectorType, read.getSource(),
read.getIndices(), read.getMask(), fill);
} else {
loadOp = rewriter.create<vector::LoadOp>(
read.getLoc(), unbroadcastedVectorType, read.getSource(),
read.getIndices());
}
// Insert a broadcasting op if required.
if (!broadcastedDims.empty()) {
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
read, read.getVectorType(), loadOp->getResult(0));
} else {
rewriter.replaceOp(read, loadOp->getResult(0));
}
return success();
}
std::optional<unsigned> maxTransferRank;
};
/// Replace a 0-d vector.load with a memref.load + vector.broadcast.
// TODO: we shouldn't cross the vector/scalar domains just for this
// but atm we lack the infra to avoid it. Possible solutions include:
// - go directly to LLVM + bitcast
// - introduce a bitcast op and likely a new pointer dialect
// - let memref.load/store additionally support the 0-d vector case
// There are still deeper data layout issues lingering even in this
// trivial case (for architectures for which this matters).
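///
/// As an illustrative sketch:
///   %0 = vector.load %mem[] : memref<f32>, vector<f32>
/// becomes:
///   %s = memref.load %mem[] : memref<f32>
///   %0 = vector.broadcast %s : f32 to vector<f32>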
struct VectorLoadToMemrefLoadLowering
: public OpRewritePattern<vector::LoadOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(vector::LoadOp loadOp,
PatternRewriter &rewriter) const override {
auto vecType = loadOp.getVectorType();
if (vecType.getNumElements() != 1)
return failure();
auto memrefLoad = rewriter.create<memref::LoadOp>(
loadOp.getLoc(), loadOp.getBase(), loadOp.getIndices());
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(loadOp, vecType,
memrefLoad);
return success();
}
};
/// Replace a 0-d vector.store with a vector.extractelement + memref.store.
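///
/// As an illustrative sketch:
///   vector.store %v, %mem[] : memref<f32>, vector<f32>
/// becomes:
///   %s = vector.extractelement %v[] : vector<f32>
///   memref.store %s, %mem[] : memref<f32>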
struct VectorStoreToMemrefStoreLowering
: public OpRewritePattern<vector::StoreOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(vector::StoreOp storeOp,
PatternRewriter &rewriter) const override {
auto vecType = storeOp.getVectorType();
if (vecType.getNumElements() != 1)
return failure();
Value extracted;
if (vecType.getRank() == 0) {
// TODO: Unify once ExtractOp supports 0-d vectors.
extracted = rewriter.create<vector::ExtractElementOp>(
storeOp.getLoc(), storeOp.getValueToStore());
} else {
SmallVector<int64_t> indices(vecType.getRank(), 0);
extracted = rewriter.create<vector::ExtractOp>(
storeOp.getLoc(), storeOp.getValueToStore(), indices);
}
rewriter.replaceOpWithNewOp<memref::StoreOp>(
storeOp, extracted, storeOp.getBase(), storeOp.getIndices());
return success();
}
};
/// Progressive lowering of transfer_write. This pattern supports lowering of
/// `vector.transfer_write` to `vector.store` if all of the following hold:
/// - Stride of most minor memref dimension must be 1.
/// - Out-of-bounds masking is not required.
/// - If the memref's element type is a vector type then it coincides with the
/// type of the written value.
/// - The permutation map is the minor identity map (neither permutation nor
/// broadcasting is allowed).
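///
/// As an illustrative sketch (the shapes are hypothetical):
///   vector.transfer_write %v, %mem[%i] {in_bounds = [true]}
///       : vector<4xf32>, memref<?xf32>
/// lowers to:
///   vector.store %v, %mem[%i] : memref<?xf32>, vector<4xf32>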
struct TransferWriteToVectorStoreLowering
: public OpRewritePattern<vector::TransferWriteOp> {
TransferWriteToVectorStoreLowering(MLIRContext *context,
std::optional<unsigned> maxRank,
PatternBenefit benefit = 1)
: OpRewritePattern<vector::TransferWriteOp>(context, benefit),
maxTransferRank(maxRank) {}
LogicalResult matchAndRewrite(vector::TransferWriteOp write,
PatternRewriter &rewriter) const override {
if (maxTransferRank && write.getVectorType().getRank() > *maxTransferRank)
return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
diag << "rank exceeds maxTransferRank: " << write;
});
// Permutations are handled by VectorToSCF or
// populateVectorTransferPermutationMapLoweringPatterns.
if ( // pass-through for the 0-d corner case.
!write.getPermutationMap().isMinorIdentity())
return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
diag << "permutation map is not minor identity: " << write;
});
auto memRefType = write.getShapedType().dyn_cast<MemRefType>();
if (!memRefType)
return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
diag << "not a memref type: " << write;
});
// Non-unit strides are handled by VectorToSCF.
if (!vector::isLastMemrefDimUnitStride(memRefType))
return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
diag << "most minor stride is not 1: " << write;
});
// `vector.store` supports vector types as memref's elements only when the
// type of the vector value being written is the same as the element type.
auto memrefElTy = memRefType.getElementType();
if (memrefElTy.isa<VectorType>() && memrefElTy != write.getVectorType())
return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
diag << "elemental type mismatch: " << write;
});
// Otherwise, element types of the memref and the vector must match.
if (!memrefElTy.isa<VectorType>() &&
memrefElTy != write.getVectorType().getElementType())
return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
diag << "elemental type mismatch: " << write;
});
// Out-of-bounds dims are handled by MaterializeTransferMask.
if (write.hasOutOfBoundsDim())
return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
diag << "out of bounds dim: " << write;
});
if (write.getMask()) {
rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
write, write.getSource(), write.getIndices(), write.getMask(),
write.getVector());
} else {
rewriter.replaceOpWithNewOp<vector::StoreOp>(
write, write.getVector(), write.getSource(), write.getIndices());
}
return success();
}
std::optional<unsigned> maxTransferRank;
};
} // namespace
void mlir::vector::populateVectorTransferLoweringPatterns(
RewritePatternSet &patterns, std::optional<unsigned> maxTransferRank,
PatternBenefit benefit) {
patterns.add<TransferReadToVectorLoadLowering,
TransferWriteToVectorStoreLowering>(patterns.getContext(),
maxTransferRank, benefit);
patterns
.add<VectorLoadToMemrefLoadLowering, VectorStoreToMemrefStoreLowering>(
patterns.getContext(), benefit);
}


@@ -0,0 +1,210 @@
//===- LowerVectorTranspose.cpp - Lower 'vector.transpose' operation ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent rewrites and utilities to lower the
// 'vector.transpose' operation.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/VectorInterfaces.h"
#include "mlir/Support/LogicalResult.h"
#define DEBUG_TYPE "vector-transpose-lowering"
using namespace mlir;
using namespace mlir::vector;
/// Given a 'transpose' pattern, prune the rightmost dimensions that are not
/// transposed.
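/// E.g. (illustrative), transpose = [1, 0, 2, 3] yields result = [1, 0]
/// because the two rightmost dims are already in place, while
/// transpose = [2, 0, 1] is returned unchanged.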
static void pruneNonTransposedDims(ArrayRef<int64_t> transpose,
SmallVectorImpl<int64_t> &result) {
size_t numTransposedDims = transpose.size();
for (size_t transpDim : llvm::reverse(transpose)) {
if (transpDim != numTransposedDims - 1)
break;
numTransposedDims--;
}
result.append(transpose.begin(), transpose.begin() + numTransposedDims);
}
namespace {
/// Progressive lowering of TransposeOp.
/// For example:
/// %x = vector.transpose %y, [1, 0]
/// is replaced by:
/// %z = arith.constant dense<0.000000e+00>
/// %0 = vector.extract %y[0, 0]
/// %1 = vector.insert %0, %z [0, 0]
/// ..
/// %x = vector.insert .., .. [.., ..]
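///
/// As a concrete, illustrative sketch (the type is chosen for exposition):
///   %x = vector.transpose %y, [1, 0] : vector<2x2xf32> to vector<2x2xf32>
/// unrolls to roughly:
///   %z = arith.constant dense<0.0> : vector<2x2xf32>
///   %0 = vector.extract %y[0, 0] : vector<2x2xf32>
///   %1 = vector.insert %0, %z [0, 0] : f32 into vector<2x2xf32>
///   %2 = vector.extract %y[0, 1] : vector<2x2xf32>
///   %3 = vector.insert %2, %1 [1, 0] : f32 into vector<2x2xf32>
///   %4 = vector.extract %y[1, 0] : vector<2x2xf32>
///   %5 = vector.insert %4, %3 [0, 1] : f32 into vector<2x2xf32>
///   %6 = vector.extract %y[1, 1] : vector<2x2xf32>
///   %x = vector.insert %6, %5 [1, 1] : f32 into vector<2x2xf32>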
class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
public:
using OpRewritePattern::OpRewritePattern;
TransposeOpLowering(vector::VectorTransformsOptions vectorTransformOptions,
MLIRContext *context, PatternBenefit benefit = 1)
: OpRewritePattern<vector::TransposeOp>(context, benefit),
vectorTransformOptions(vectorTransformOptions) {}
LogicalResult matchAndRewrite(vector::TransposeOp op,
PatternRewriter &rewriter) const override {
auto loc = op.getLoc();
Value input = op.getVector();
VectorType inputType = op.getSourceVectorType();
VectorType resType = op.getResultVectorType();
// Set up convenience transposition table.
SmallVector<int64_t> transp;
for (auto attr : op.getTransp())
transp.push_back(attr.cast<IntegerAttr>().getInt());
if (vectorTransformOptions.vectorTransposeLowering ==
vector::VectorTransposeLowering::Shuffle &&
resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0)
return rewriter.notifyMatchFailure(
op, "Options specifies lowering to shuffle");
// Handle a true 2-D matrix transpose differently when requested.
if (vectorTransformOptions.vectorTransposeLowering ==
vector::VectorTransposeLowering::Flat &&
resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0) {
Type flattenedType =
VectorType::get(resType.getNumElements(), resType.getElementType());
auto matrix =
rewriter.create<vector::ShapeCastOp>(loc, flattenedType, input);
auto rows = rewriter.getI32IntegerAttr(resType.getShape()[0]);
auto columns = rewriter.getI32IntegerAttr(resType.getShape()[1]);
Value trans = rewriter.create<vector::FlatTransposeOp>(
loc, flattenedType, matrix, rows, columns);
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, trans);
return success();
}
// Generate unrolled extract/insert ops. We do not unroll the rightmost
// (i.e., highest-order) dimensions that are not transposed and leave them
// in vector form to improve performance. Therefore, we prune those
// dimensions from the shape/transpose data structures used to generate the
// extract/insert ops.
SmallVector<int64_t> prunedTransp;
pruneNonTransposedDims(transp, prunedTransp);
size_t numPrunedDims = transp.size() - prunedTransp.size();
auto prunedInShape = inputType.getShape().drop_back(numPrunedDims);
auto prunedInStrides = computeStrides(prunedInShape);
// Generates the extract/insert operations for every scalar/vector element
// of the leftmost transposed dimensions. We traverse every transpose
// element using a linearized index that we delinearize to generate the
// appropriate indices for the extract/insert operations.
Value result = rewriter.create<arith::ConstantOp>(
loc, resType, rewriter.getZeroAttr(resType));
int64_t numTransposedElements = ShapedType::getNumElements(prunedInShape);
for (int64_t linearIdx = 0; linearIdx < numTransposedElements;
++linearIdx) {
auto extractIdxs = delinearize(linearIdx, prunedInStrides);
SmallVector<int64_t> insertIdxs(extractIdxs);
applyPermutationToVector(insertIdxs, prunedTransp);
Value extractOp =
rewriter.create<vector::ExtractOp>(loc, input, extractIdxs);
result =
rewriter.create<vector::InsertOp>(loc, extractOp, result, insertIdxs);
}
rewriter.replaceOp(op, result);
return success();
}
private:
/// Options to control the vector patterns.
vector::VectorTransformsOptions vectorTransformOptions;
};
/// Rewrite a 2-D vector.transpose as a sequence of:
/// vector.shape_cast 2D -> 1D
/// vector.shuffle
/// vector.shape_cast 1D -> 2D
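///
/// As an illustrative sketch (the shape is hypothetical):
///   %0 = vector.transpose %a, [1, 0] : vector<2x3xf32> to vector<3x2xf32>
/// becomes roughly:
///   %c = vector.shape_cast %a : vector<2x3xf32> to vector<6xf32>
///   %s = vector.shuffle %c, %c [0, 3, 1, 4, 2, 5]
///       : vector<6xf32>, vector<6xf32>
///   %0 = vector.shape_cast %s : vector<6xf32> to vector<3x2xf32>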
class TransposeOp2DToShuffleLowering
: public OpRewritePattern<vector::TransposeOp> {
public:
using OpRewritePattern::OpRewritePattern;
TransposeOp2DToShuffleLowering(
vector::VectorTransformsOptions vectorTransformOptions,
MLIRContext *context, PatternBenefit benefit = 1)
: OpRewritePattern<vector::TransposeOp>(context, benefit),
vectorTransformOptions(vectorTransformOptions) {}
LogicalResult matchAndRewrite(vector::TransposeOp op,
PatternRewriter &rewriter) const override {
auto loc = op.getLoc();
VectorType srcType = op.getSourceVectorType();
if (srcType.getRank() != 2)
return rewriter.notifyMatchFailure(op, "Not a 2D transpose");
SmallVector<int64_t> transp;
for (auto attr : op.getTransp())
transp.push_back(attr.cast<IntegerAttr>().getInt());
if (transp[0] != 1 || transp[1] != 0)
return rewriter.notifyMatchFailure(op, "Not a 2D transpose permutation");
if (vectorTransformOptions.vectorTransposeLowering !=
VectorTransposeLowering::Shuffle)
return rewriter.notifyMatchFailure(op, "Options do not ask for Shuffle");
int64_t m = srcType.getShape().front(), n = srcType.getShape().back();
Value casted = rewriter.create<vector::ShapeCastOp>(
loc, VectorType::get({m * n}, srcType.getElementType()),
op.getVector());
SmallVector<int64_t> mask;
mask.reserve(m * n);
for (int64_t j = 0; j < n; ++j)
for (int64_t i = 0; i < m; ++i)
mask.push_back(i * n + j);
Value shuffled =
rewriter.create<vector::ShuffleOp>(loc, casted, casted, mask);
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
op, op.getResultVectorType(), shuffled);
return success();
}
private:
/// Options to control the vector patterns.
vector::VectorTransformsOptions vectorTransformOptions;
};
} // namespace
void mlir::vector::populateVectorTransposeLoweringPatterns(
RewritePatternSet &patterns, VectorTransformsOptions options,
PatternBenefit benefit) {
patterns.add<TransposeOpLowering, TransposeOp2DToShuffleLowering>(
options, patterns.getContext(), benefit);
}
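// Illustrative usage sketch (not part of this change; pass boilerplate is
// assumed): request the shuffle-based 2-D lowering through the options, then
// apply the patterns greedily.
//   RewritePatternSet patterns(ctx);
//   VectorTransformsOptions options;
//   options.vectorTransposeLowering = VectorTransposeLowering::Shuffle;
//   vector::populateVectorTransposeLoweringPatterns(patterns, options);
//   (void)applyPatternsAndFoldGreedily(op, std::move(patterns));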


@@ -16,6 +16,7 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/BuiltinOps.h"


@@ -11,8 +11,8 @@
//
//===----------------------------------------------------------------------===//
#include <type_traits>
#include <optional>
#include <type_traits>
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
@@ -92,11 +92,11 @@ static Value createInBoundsCond(RewriterBase &b,
}
/// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds
/// masking) fastpath and a slowpath.
/// masking) fast path and a slow path.
/// If `ifOp` is not null and the result is `success`, the `ifOp` points to the
/// newly created conditional upon function return.
/// To accomodate for the fact that the original vector.transfer indexing may be
/// arbitrary and the slow path indexes @[0...0] in the temporary buffer, the
/// To accommodate for the fact that the original vector.transfer indexing may
/// be arbitrary and the slow path indexes @[0...0] in the temporary buffer, the
/// scf.if op returns a view and values of type index.
/// At this time, only vector.transfer_read case is implemented.
///
@@ -107,11 +107,11 @@ static Value createInBoundsCond(RewriterBase &b,
/// is transformed into:
/// ```
/// %1:3 = scf.if (%inBounds) {
/// // fastpath, direct cast
/// // fast path, direct cast
/// memref.cast %A: memref<A...> to compatibleMemRefType
/// scf.yield %view : compatibleMemRefType, index, index
/// } else {
/// // slowpath, not in-bounds vector.transfer or linalg.copy.
/// // slow path, not in-bounds vector.transfer or linalg.copy.
/// memref.cast %alloc: memref<B...> to compatibleMemRefType
/// scf.yield %4 : compatibleMemRefType, index, index
// }
@@ -172,12 +172,10 @@ static MemRefType getCastCompatibleMemRefType(MemRefType aT, MemRefType bT) {
for (int64_t idx = 0, e = aT.getRank(); idx < e; ++idx) {
resShape[idx] =
(aShape[idx] == bShape[idx]) ? aShape[idx] : ShapedType::kDynamic;
resStrides[idx] = (aStrides[idx] == bStrides[idx])
? aStrides[idx]
: ShapedType::kDynamic;
resStrides[idx] =
(aStrides[idx] == bStrides[idx]) ? aStrides[idx] : ShapedType::kDynamic;
}
resOffset =
(aOffset == bOffset) ? aOffset : ShapedType::kDynamic;
resOffset = (aOffset == bOffset) ? aOffset : ShapedType::kDynamic;
return MemRefType::get(
resShape, aT.getElementType(),
StridedLayoutAttr::get(aT.getContext(), resOffset, resStrides));
@@ -634,7 +632,34 @@ LogicalResult mlir::vector::splitFullAndPartialTransfer(
return success();
}
LogicalResult mlir::vector::VectorTransferFullPartialRewriter::matchAndRewrite(
namespace {
/// Apply `splitFullAndPartialTransfer` selectively via a pattern. This pattern
/// may take an extra filter to perform selection at a finer granularity.
struct VectorTransferFullPartialRewriter : public RewritePattern {
using FilterConstraintType =
std::function<LogicalResult(VectorTransferOpInterface op)>;
explicit VectorTransferFullPartialRewriter(
MLIRContext *context,
VectorTransformsOptions options = VectorTransformsOptions(),
FilterConstraintType filter =
[](VectorTransferOpInterface op) { return success(); },
PatternBenefit benefit = 1)
: RewritePattern(MatchAnyOpTypeTag(), benefit, context), options(options),
filter(std::move(filter)) {}
/// Performs the rewrite.
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override;
private:
VectorTransformsOptions options;
FilterConstraintType filter;
};
} // namespace
LogicalResult VectorTransferFullPartialRewriter::matchAndRewrite(
Operation *op, PatternRewriter &rewriter) const {
auto xferOp = dyn_cast<VectorTransferOpInterface>(op);
if (!xferOp || failed(splitFullAndPartialTransferPrecondition(xferOp)) ||
@@ -642,3 +667,9 @@ LogicalResult mlir::vector::VectorTransferFullPartialRewriter::matchAndRewrite(
return failure();
return splitFullAndPartialTransfer(rewriter, xferOp, options);
}
void mlir::vector::populateVectorTransferFullPartialPatterns(
RewritePatternSet &patterns, const VectorTransformsOptions &options) {
patterns.add<VectorTransferFullPartialRewriter>(patterns.getContext(),
options);
}

File diff suppressed because it is too large


@@ -22,6 +22,7 @@
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
@@ -148,8 +149,9 @@ struct TestVectorContractionLowering
if (lowerToOuterProduct) {
VectorContractLowering lowering = VectorContractLowering::OuterProduct;
VectorTransformsOptions options{lowering};
patterns.add<ContractionOpToOuterProductOpLowering>(options,
&getContext());
populateVectorContractLoweringPatterns(
patterns, options, /*benefit=*/1,
/*disableOuterProductLowering=*/true);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
return;
}
@@ -469,7 +471,7 @@ struct TestVectorTransferFullPartialSplitPatterns
options.setVectorTransferSplit(VectorTransferSplit::LinalgCopy);
else
options.setVectorTransferSplit(VectorTransferSplit::VectorTransfer);
patterns.add<VectorTransferFullPartialRewriter>(ctx, options);
populateVectorTransferFullPartialPatterns(patterns, options);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
};


@@ -8539,6 +8539,7 @@ cc_library(
":TransformDialect",
":TransformDialectUtils",
":TransformUtils",
":VectorTransforms",
"//llvm:Support",
],
)