[mlir][Vector] NFC - Reorganize vector patterns
Vector dialect patterns have grown enormously in the past year, to the point where they are now impenetrable. Start reorganizing them towards finer-grained control.

Differential Revision: https://reviews.llvm.org/D146736
@@ -110,43 +110,11 @@ void populateFlattenVectorTransferPatterns(RewritePatternSet &patterns,
void populateBubbleVectorBitCastOpPatterns(RewritePatternSet &patterns,
                                           PatternBenefit benefit = 1);

/// Collect a set of transfer read/write lowering patterns.
///
/// These patterns lower transfer ops to simpler ops like `vector.load`,
/// `vector.store` and `vector.broadcast`. Only transfers with a transfer rank
/// of at most `maxTransferRank` are lowered. This is useful when combined with
/// VectorToSCF, which reduces the rank of vector transfer ops.
void populateVectorTransferLoweringPatterns(
    RewritePatternSet &patterns,
    std::optional<unsigned> maxTransferRank = std::nullopt,
    PatternBenefit benefit = 1);

/// These patterns materialize masks for various vector ops such as transfers.
void populateVectorMaskMaterializationPatterns(RewritePatternSet &patterns,
                                               bool force32BitVectorIndices,
                                               PatternBenefit benefit = 1);

/// Collects patterns to progressively lower vector.broadcast ops on high-D
/// vectors to low-D vector ops.
void populateVectorBroadcastLoweringPatterns(RewritePatternSet &patterns,
                                             PatternBenefit benefit = 1);

/// Collects patterns to progressively lower vector mask ops into elementary
/// selection and insertion ops.
void populateVectorMaskOpLoweringPatterns(RewritePatternSet &patterns,
                                          PatternBenefit benefit = 1);

/// Collects patterns to progressively lower vector.shape_cast ops on high-D
/// vectors into 1-D/2-D vector ops by generating data movement extract/insert
/// ops.
void populateVectorShapeCastLoweringPatterns(RewritePatternSet &patterns,
                                             PatternBenefit benefit = 1);

/// Collects patterns that lower scalar vector transfer ops to memref loads and
/// stores when beneficial.
void populateScalarVectorTransferLoweringPatterns(RewritePatternSet &patterns,
                                                  PatternBenefit benefit = 1);

/// Returns the integer type required for subscripts in the vector dialect.
IntegerType getVectorSubscriptType(Builder &builder);

@@ -214,8 +182,8 @@ void createMaskOpRegion(OpBuilder &builder, Operation *maskableOp);
/// Creates a vector.mask operation around a maskable operation. Returns the
/// vector.mask operation if the mask provided is valid. Otherwise, returns the
/// maskable operation itself.
Operation *maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask,
                         Value passthru = Value());

/// Creates a vector select operation that picks values from `newValue` or
/// `passthru` for each result vector lane based on `mask`. This utility is used
mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h (new file, 248 lines)
@@ -0,0 +1,248 @@
//===- LoweringPatterns.h - Vector rewrite patterns --------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H
#define MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H

#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"

namespace mlir {
class RewritePatternSet;

namespace vector {

//===----------------------------------------------------------------------===//
// Lowering pattern populate functions
//===----------------------------------------------------------------------===//

/// Populate the pattern set with the following patterns:
///
/// [OuterProductOpLowering]
/// Progressively lower a `vector.outerproduct` to linearized
/// `vector.extract` + `vector.fma` + `vector.insert`.
///
/// [ContractionOpLowering]
/// Progressive lowering of ContractionOp.
/// One:
///   %x = vector.contract with at least one free/batch dimension
/// is replaced by:
///   %a = vector.contract with one less free/batch dimension
///   %b = vector.contract with one less free/batch dimension
///
/// [ContractionOpToMatmulOpLowering]
/// Progressively lower a `vector.contract` with row-major matmul semantics to
/// linearized `vector.shape_cast` + `vector.matmul` on the way to
/// `llvm.matrix.multiply`.
///
/// [ContractionOpToDotLowering]
/// Progressively lower a `vector.contract` with row-major matmul semantics to
/// linearized `vector.extract` + `vector.reduce` + `vector.insert`.
///
/// [ContractionOpToOuterProductOpLowering]
/// Progressively lower a `vector.contract` with row-major matmul semantics to
/// linearized `vector.extract` + `vector.outerproduct` + `vector.insert`.
void populateVectorContractLoweringPatterns(
    RewritePatternSet &patterns, VectorTransformsOptions options,
    PatternBenefit benefit = 1, bool disableOuterProductLowering = false);

/// Collect a set of patterns to convert vector.multi_reduction op into
/// a sequence of vector.reduction ops. The patterns comprise:
///
/// [InnerOuterDimReductionConversion]
/// Rewrites vector.multi_reduction such that all reduction dimensions are
/// either innermost or outermost, by adding the proper vector.transpose
/// operations.
///
/// [ReduceMultiDimReductionRank]
/// Once in innermost or outermost reduction
/// form, rewrites n-D vector.multi_reduction into 2-D vector.multi_reduction,
/// by introducing vector.shape_cast ops to collapse + multi-reduce + expand
/// back.
///
/// [TwoDimMultiReductionToElementWise]
/// Once in 2-D vector.multi_reduction form, with an **outermost** reduction
/// dimension, unroll the outer dimension to obtain a sequence of 1-D vector
/// ops. This also has an opportunity for tree-reduction (in the future).
///
/// [TwoDimMultiReductionToReduction]
/// Once in 2-D vector.multi_reduction form, with an **innermost** reduction
/// dimension, unroll the outer dimension to obtain a sequence of extract +
/// vector.reduction + insert. This can further lower to horizontal reduction
/// ops.
///
/// [OneDimMultiReductionToTwoDim]
/// For cases that reduce to 1-D vector<k> reduction (and are thus missing
/// either a parallel or a reduction), we lift them back up to 2-D with a simple
/// vector.shape_cast to vector<1xk> so that the other patterns can kick in,
/// thus fully exiting out of the vector.multi_reduction abstraction.
void populateVectorMultiReductionLoweringPatterns(
    RewritePatternSet &patterns, VectorMultiReductionLowering options,
    PatternBenefit benefit = 1);

/// Populate the pattern set with the following patterns:
///
/// [BroadcastOpLowering]
/// Progressive lowering of BroadcastOp to ExtractOp + InsertOp + lower-D
/// BroadcastOp until dim 1.
void populateVectorBroadcastLoweringPatterns(RewritePatternSet &patterns,
                                             PatternBenefit benefit = 1);

/// Populate the pattern set with the following patterns:
///
/// [CreateMaskOp]
/// Progressive lowering of CreateMaskOp to lower-D CreateMaskOp until dim 1.
///
/// [ConstantMaskOp]
/// Progressive lowering of ConstantMaskOp to lower-D ConstantMaskOp until
/// dim 1.
void populateVectorMaskOpLoweringPatterns(RewritePatternSet &patterns,
                                          PatternBenefit benefit = 1);

/// Collects patterns that lower scalar vector transfer ops to memref loads and
/// stores when beneficial.
void populateScalarVectorTransferLoweringPatterns(RewritePatternSet &patterns,
                                                  PatternBenefit benefit = 1);

/// Populate the pattern set with the following patterns:
///
/// [ShapeCastOp2DDownCastRewritePattern]
/// ShapeOp 2D -> 1D downcast serves the purpose of flattening 2-D to 1-D
/// vectors progressively.
///
/// [ShapeCastOp2DUpCastRewritePattern]
/// ShapeOp 1D -> 2D upcast serves the purpose of unflattening 2-D from 1-D
/// vectors progressively.
///
/// [ShapeCastOpRewritePattern]
/// Reference lowering to fully unrolled sequences of single element ExtractOp +
/// InsertOp. Note that applying this pattern can almost always be considered a
/// performance bug.
void populateVectorShapeCastLoweringPatterns(RewritePatternSet &patterns,
                                             PatternBenefit benefit = 1);

/// Populate the pattern set with the following patterns:
///
/// [TransposeOpLowering]
///
/// [TransposeOp2DToShuffleLowering]
///
void populateVectorTransposeLoweringPatterns(RewritePatternSet &patterns,
                                             VectorTransformsOptions options,
                                             PatternBenefit benefit = 1);

/// Populate the pattern set with the following patterns:
///
/// [TransferReadToVectorLoadLowering]
/// Progressive lowering of transfer_read. This pattern supports lowering of
/// `vector.transfer_read` to a combination of `vector.load` and
/// `vector.broadcast`
///
/// [TransferWriteToVectorStoreLowering]
/// Progressive lowering of transfer_write. This pattern supports lowering of
/// `vector.transfer_write` to `vector.store`
///
/// [VectorLoadToMemrefLoadLowering]
/// Replace a 0-d vector.load with a memref.load + vector.broadcast.
///
/// [VectorStoreToMemrefStoreLowering]
/// Replace a 0-d vector.store with a vector.extractelement + memref.store.
///
/// These patterns lower transfer ops to simpler ops like `vector.load`,
/// `vector.store` and `vector.broadcast`. Only transfers with a transfer rank
/// of at most `maxTransferRank` are lowered. This is useful when combined with
/// VectorToSCF, which reduces the rank of vector transfer ops.
void populateVectorTransferLoweringPatterns(
    RewritePatternSet &patterns,
    std::optional<unsigned> maxTransferRank = std::nullopt,
    PatternBenefit benefit = 1);

/// Collect a set of transfer read/write lowering patterns that simplify the
/// permutation map (e.g., converting it to a minor identity map) by inserting
/// broadcasts and transposes. More specifically:
///
/// [TransferReadPermutationLowering]
/// Lower transfer_read op with permutation into a transfer_read with a
/// permutation map composed of leading zeros followed by a minor identity +
/// vector.transpose op.
/// Ex:
///     vector.transfer_read ...
///         permutation_map: (d0, d1, d2) -> (0, d1)
/// into:
///     %v = vector.transfer_read ...
///         permutation_map: (d0, d1, d2) -> (d1, 0)
///     vector.transpose %v, [1, 0]
///
///     vector.transfer_read ...
///         permutation_map: (d0, d1, d2, d3) -> (0, 0, 0, d1, d3)
/// into:
///     %v = vector.transfer_read ...
///         permutation_map: (d0, d1, d2, d3) -> (0, 0, d1, 0, d3)
///     vector.transpose %v, [0, 1, 3, 2, 4]
/// Note that an alternative is to transform it to linalg.transpose +
/// vector.transfer_read to do the transpose in memory instead.
///
/// [TransferWritePermutationLowering]
/// Lower transfer_write op with permutation into a transfer_write with a
/// minor identity permutation map. (transfer_write ops cannot have broadcasts.)
/// Ex:
///     vector.transfer_write %v ...
///         permutation_map: (d0, d1, d2) -> (d2, d0, d1)
/// into:
///     %tmp = vector.transpose %v, [2, 0, 1]
///     vector.transfer_write %tmp ...
///         permutation_map: (d0, d1, d2) -> (d0, d1, d2)
///
///     vector.transfer_write %v ...
///         permutation_map: (d0, d1, d2, d3) -> (d3, d2)
/// into:
///     %tmp = vector.transpose %v, [1, 0]
///     %v = vector.transfer_write %tmp ...
///         permutation_map: (d0, d1, d2, d3) -> (d2, d3)
///
/// [TransferOpReduceRank]
/// Lower transfer_read op with broadcast in the leading dimensions into
/// transfer_read of lower rank + vector.broadcast.
/// Ex: vector.transfer_read ...
///         permutation_map: (d0, d1, d2, d3) -> (0, d1, 0, d3)
/// into:
///     %v = vector.transfer_read ...
///         permutation_map: (d0, d1, d2, d3) -> (d1, 0, d3)
///     vector.broadcast %v
void populateVectorTransferPermutationMapLoweringPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit = 1);

/// Populate the pattern set with the following patterns:
///
/// [ScanToArithOps]
/// Convert vector.scan op into arith ops and vector.insert_strided_slice /
/// vector.extract_strided_slice.
void populateVectorScanLoweringPatterns(RewritePatternSet &patterns,
                                        PatternBenefit benefit = 1);

/// Populate the pattern set with the following patterns:
///
/// [FlattenGather]
/// Flattens 2 or more dimensional `vector.gather` ops by unrolling the
/// outermost dimension.
///
/// [Gather1DToConditionalLoads]
/// Turns 1-d `vector.gather` into a scalarized sequence of `vector.load`s or
/// `tensor.extract`s. To avoid out-of-bounds memory accesses, these
/// loads/extracts are made conditional using `scf.if` ops.
void populateVectorGatherLoweringPatterns(RewritePatternSet &patterns,
                                          PatternBenefit benefit = 1);

/// Populates instances of `MaskOpRewritePattern` to lower masked operations
/// with `vector.mask`. Patterns should rewrite the `vector.mask` operation and
/// not its nested `MaskableOpInterface`.
void populateVectorMaskLoweringPatternsForSideEffectingOps(
    RewritePatternSet &patterns);

} // namespace vector
} // namespace mlir
#endif // MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H
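For orientation, here is a minimal sketch of how a client could drive the populate functions declared in this new header. The `lowerVectorOps` wrapper is hypothetical and not part of this commit; only the populate calls and the greedy driver come from the MLIR API.

// Sketch only: combine several of the populate functions above and apply
// them greedily. `lowerVectorOps` is an illustrative helper, not LLVM code.
static LogicalResult lowerVectorOps(Operation *root) {
  MLIRContext *ctx = root->getContext();
  RewritePatternSet patterns(ctx);
  vector::VectorTransformsOptions options;
  vector::populateVectorContractLoweringPatterns(patterns, options);
  vector::populateVectorMultiReductionLoweringPatterns(
      patterns, options.vectorMultiReductionLowering);
  vector::populateVectorShapeCastLoweringPatterns(patterns);
  vector::populateVectorBroadcastLoweringPatterns(patterns);
  return applyPatternsAndFoldGreedily(root, std::move(patterns));
}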
@@ -22,12 +22,6 @@ std::unique_ptr<Pass> createVectorBufferizePass();
/// Creates an instance of the `vector.mask` lowering pass.
std::unique_ptr<Pass> createLowerVectorMaskPass();

/// Populates instances of `MaskOpRewritePattern` to lower masked operations
/// with `vector.mask`. Patterns should rewrite the `vector.mask` operation and
/// not its nested `MaskableOpInterface`.
void populateVectorMaskLoweringPatternsForSideEffectingOps(
    RewritePatternSet &patterns);

//===----------------------------------------------------------------------===//
// Registration
//===----------------------------------------------------------------------===//
@@ -9,8 +9,8 @@
#ifndef MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORREWRITEPATTERNS_H
#define MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORREWRITEPATTERNS_H

#include <optional>
#include <utility>

#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransformsEnums.h.inc"
@@ -23,42 +23,7 @@ namespace mlir {
class RewritePatternSet;

namespace vector {

//===----------------------------------------------------------------------===//
// Vector transformation options exposed as auxiliary structs.
//===----------------------------------------------------------------------===//
/// Structure to control the behavior of vector transform patterns.
struct VectorTransformsOptions {
  /// Option to control the lowering of vector.contract.
  VectorContractLowering vectorContractLowering = VectorContractLowering::Dot;
  VectorTransformsOptions &
  setVectorTransformsOptions(VectorContractLowering opt) {
    vectorContractLowering = opt;
    return *this;
  }
  /// Option to control the lowering of vector.multi_reduction.
  VectorMultiReductionLowering vectorMultiReductionLowering =
      VectorMultiReductionLowering::InnerParallel;
  VectorTransformsOptions &
  setVectorMultiReductionLowering(VectorMultiReductionLowering opt) {
    vectorMultiReductionLowering = opt;
    return *this;
  }
  /// Option to control the lowering of vector.transpose.
  VectorTransposeLowering vectorTransposeLowering =
      VectorTransposeLowering::EltWise;
  VectorTransformsOptions &
  setVectorTransposeLowering(VectorTransposeLowering opt) {
    vectorTransposeLowering = opt;
    return *this;
  }
  /// Option to control the splitting of vector transfers.
  VectorTransferSplit vectorTransferSplit = VectorTransferSplit::None;
  VectorTransformsOptions &setVectorTransferSplit(VectorTransferSplit opt) {
    vectorTransferSplit = opt;
    return *this;
  }
};
struct VectorTransformsOptions;

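Since each setter of the struct above returns `*this`, the options compose by chaining. A small usage sketch follows; the particular enum choices are arbitrary, picked only to show the shape of the call.

// Sketch: chain the fluent setters of VectorTransformsOptions.
auto options =
    vector::VectorTransformsOptions()
        .setVectorTransformsOptions(
            vector::VectorContractLowering::OuterProduct)
        .setVectorMultiReductionLowering(
            vector::VectorMultiReductionLowering::InnerReduction)
        .setVectorTransposeLowering(vector::VectorTransposeLowering::Shuffle)
        .setVectorTransferSplit(vector::VectorTransferSplit::VectorTransfer);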
/// Options that control the vector unrolling.
struct UnrollVectorOptions {
@@ -109,45 +74,6 @@ struct UnrollVectorOptions {
// Vector transformation exposed as populate functions over rewrite patterns.
//===----------------------------------------------------------------------===//

/// Collects patterns to lower vector.transpose into extraction/insertion ops.
void populateVectorTransposeLoweringPatterns(
    RewritePatternSet &patterns,
    VectorTransformsOptions options = VectorTransformsOptions(),
    PatternBenefit benefit = 1);

/// Collect a set of patterns to convert vector.multi_reduction op into
/// a sequence of vector.reduction ops. The patterns comprise:
/// - InnerOuterDimReductionConversion: rewrites vector.multi_reduction such
/// that all reduction dimensions are either innermost or outermost, by adding
/// the proper vector.transpose operations.
/// - ReduceMultiDimReductionRank: once in innermost or outermost reduction
/// form, rewrites n-D vector.multi_reduction into 2-D vector.multi_reduction,
/// by introducing vector.shape_cast ops to collapse + multi-reduce + expand
/// back.
/// - TwoDimMultiReductionToElementWise: once in 2-D vector.multi_reduction
/// form, with an **outermost** reduction dimension, unroll the outer dimension
/// to obtain a sequence of 1-D vector ops. This also has an opportunity for
/// tree-reduction (in the future).
/// - TwoDimMultiReductionToReduction: once in 2-D vector.multi_reduction form,
/// with an **innermost** reduction dimension, unroll the outer dimension to
/// obtain a sequence of extract + vector.reduction + insert. This can further
/// lower to horizontal reduction ops.
/// - OneDimMultiReductionToTwoDim: for cases that reduce to 1-D vector<k>
/// reduction (and are thus missing either a parallel or a reduction), we lift
/// them back up to 2-D with a simple vector.shape_cast to vector<1xk> so that
/// the other patterns can kick in, thus fully exiting out of the
/// vector.multi_reduction abstraction.
void populateVectorMultiReductionLoweringPatterns(
    RewritePatternSet &patterns, VectorMultiReductionLowering options,
    PatternBenefit benefit = 1);

/// Collects patterns to progressively lower vector contraction ops on high-D
/// into low-D reduction and product ops.
void populateVectorContractLoweringPatterns(
    RewritePatternSet &patterns,
    VectorTransformsOptions options = VectorTransformsOptions(),
    PatternBenefit benefit = 1);

/// Canonicalization of a `vector.contraction %a, %b, %c` with row-major matmul
/// semantics to a contraction with MMT semantics (matrix matrix multiplication
/// with the RHS transposed). This specific form is meant to have the vector
@@ -174,67 +100,43 @@ void populateVectorContractCanonicalizeMatmulToMMT(
void populateVectorReductionToContractPatterns(RewritePatternSet &patterns,
                                               PatternBenefit benefit = 1);

/// Collect patterns to convert vector.scan ops.
void populateVectorScanLoweringPatterns(RewritePatternSet &patterns,
                                        PatternBenefit benefit = 1);

/// Populate `patterns` with the following patterns.
///
/// - VectorTransferFullPartialRewriter
///
/// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds
/// masking) fast path and a slow path.
///
/// Example (a 2-D vector.transfer_read):
/// ```
///    %1 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...>
/// ```
/// is transformed into:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      // fast path, direct cast
///      memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view : compatibleMemRefType, index, index
///    } else {
///      // slow path, not in-bounds vector.transfer or linalg.copy.
///      memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4 : compatibleMemRefType, index, index
///    }
///    %0 = vector.transfer_read %1#0[%1#1, %1#2] {in_bounds = [true ... true]}
/// ```
/// where `alloc` is a top-of-the-function alloca'ed buffer of one vector.
///
/// Preconditions:
///  1. `xferOp.permutation_map()` must be a minor identity map
///  2. the rank of the `xferOp.memref()` and the rank of the `xferOp.vector()`
///  must be equal. This will be relaxed in the future but requires
///  rank-reducing subviews.
void populateVectorTransferFullPartialPatterns(
    RewritePatternSet &patterns, const VectorTransformsOptions &options);

//===----------------------------------------------------------------------===//
// Vector.transfer patterns.
//===----------------------------------------------------------------------===//
/// Collect a set of transfer read/write lowering patterns that simplify the
/// permutation map (e.g., converting it to a minor identity map) by inserting
/// broadcasts and transposes. More specifically:
///
/// [TransferReadPermutationLowering]
/// Lower transfer_read op with permutation into a transfer_read with a
/// permutation map composed of leading zeros followed by a minor identity +
/// vector.transpose op.
/// Ex:
///     vector.transfer_read ...
///         permutation_map: (d0, d1, d2) -> (0, d1)
/// into:
///     %v = vector.transfer_read ...
///         permutation_map: (d0, d1, d2) -> (d1, 0)
///     vector.transpose %v, [1, 0]
///
///     vector.transfer_read ...
///         permutation_map: (d0, d1, d2, d3) -> (0, 0, 0, d1, d3)
/// into:
///     %v = vector.transfer_read ...
///         permutation_map: (d0, d1, d2, d3) -> (0, 0, d1, 0, d3)
///     vector.transpose %v, [0, 1, 3, 2, 4]
/// Note that an alternative is to transform it to linalg.transpose +
/// vector.transfer_read to do the transpose in memory instead.
///
/// [TransferWritePermutationLowering]
/// Lower transfer_write op with permutation into a transfer_write with a
/// minor identity permutation map. (transfer_write ops cannot have broadcasts.)
/// Ex:
///     vector.transfer_write %v ...
///         permutation_map: (d0, d1, d2) -> (d2, d0, d1)
/// into:
///     %tmp = vector.transpose %v, [2, 0, 1]
///     vector.transfer_write %tmp ...
///         permutation_map: (d0, d1, d2) -> (d0, d1, d2)
///
///     vector.transfer_write %v ...
///         permutation_map: (d0, d1, d2, d3) -> (d3, d2)
/// into:
///     %tmp = vector.transpose %v, [1, 0]
///     %v = vector.transfer_write %tmp ...
///         permutation_map: (d0, d1, d2, d3) -> (d2, d3)
///
/// [TransferOpReduceRank]
/// Lower transfer_read op with broadcast in the leading dimensions into
/// transfer_read of lower rank + vector.broadcast.
/// Ex: vector.transfer_read ...
///         permutation_map: (d0, d1, d2, d3) -> (0, d1, 0, d3)
/// into:
///     %v = vector.transfer_read ...
///         permutation_map: (d0, d1, d2, d3) -> (d1, 0, d3)
///     vector.broadcast %v
void populateVectorTransferPermutationMapLoweringPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit = 1);

/// Collect a set of patterns to reduce the rank of the operands of vector
/// transfer ops to operate on the largest contiguous vector.
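Before moving on, a hedged usage sketch for the full/partial split documented above: the slow-path strategy is picked through the options struct. `ctx` is an assumed MLIRContext; the rest uses the declarations shown in this header.

// Sketch: request the LinalgCopy slow path for the full/partial split.
RewritePatternSet patterns(ctx);
auto opts = vector::VectorTransformsOptions().setVectorTransferSplit(
    vector::VectorTransferSplit::LinalgCopy);
vector::populateVectorTransferFullPartialPatterns(patterns, opts);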
@@ -334,220 +236,6 @@ void populateVectorUnrollPatterns(RewritePatternSet &patterns,
                                  const UnrollVectorOptions &options,
                                  PatternBenefit benefit = 1);

/// Expands `vector.gather` ops into a series of conditional scalar loads
/// (`vector.load` for memrefs or `tensor.extract` for tensors). These loads are
/// conditional to avoid out-of-bounds memory accesses and guarded with `scf.if`
/// ops. This lowering path is intended for targets that do not feature
/// dedicated gather ops.
void populateVectorGatherLoweringPatterns(RewritePatternSet &patterns,
                                          PatternBenefit benefit = 1);

//===----------------------------------------------------------------------===//
// Finer-grained patterns exposed for more control over individual lowerings.
//===----------------------------------------------------------------------===//
/// Apply `splitFullAndPartialTransfer` selectively via a pattern. This pattern
/// may take an extra filter to perform selection at a finer granularity.
struct VectorTransferFullPartialRewriter : public RewritePattern {
  using FilterConstraintType =
      std::function<LogicalResult(VectorTransferOpInterface op)>;

  explicit VectorTransferFullPartialRewriter(
      MLIRContext *context,
      VectorTransformsOptions options = VectorTransformsOptions(),
      FilterConstraintType filter =
          [](VectorTransferOpInterface op) { return success(); },
      PatternBenefit benefit = 1)
      : RewritePattern(MatchAnyOpTypeTag(), benefit, context), options(options),
        filter(std::move(filter)) {}

  /// Performs the rewrite.
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override;

private:
  VectorTransformsOptions options;
  FilterConstraintType filter;
};

/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
/// semantics to:
/// ```
///    %flattened_a = vector.shape_cast %a
///    %flattened_b = vector.shape_cast %b
///    %flattened_d = vector.matmul %flattened_a, %flattened_b
///    %d = vector.shape_cast %flattened_d
///    %e = add %c, %d
/// ```
/// `vector.matmul` later lowers to `llvm.matrix.multiply`.
///
/// This only kicks in when VectorTransformsOptions is set to Matmul and
/// the vector.contract op is a row-major matrix multiply.
class ContractionOpToMatmulOpLowering
    : public OpRewritePattern<vector::ContractionOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  using FilterConstraintType =
      std::function<LogicalResult(vector::ContractionOp op)>;

  static LogicalResult defaultFilter(vector::ContractionOp op) {
    return success();
  }

  ContractionOpToMatmulOpLowering(
      vector::VectorTransformsOptions vectorTransformOptions,
      MLIRContext *context, PatternBenefit benefit = 1,
      FilterConstraintType constraint = defaultFilter)
      : OpRewritePattern<vector::ContractionOp>(context, benefit),
        vectorTransformOptions(vectorTransformOptions),
        filter(std::move(constraint)) {}

  LogicalResult matchAndRewrite(vector::ContractionOp op,
                                PatternRewriter &rewriter) const override;

private:
  /// Options to control the vector patterns.
  vector::VectorTransformsOptions vectorTransformOptions;
  FilterConstraintType filter;
};

/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
/// semantics to a reduction_size-unrolled sequence:
/// ```
///    %at = vector.transpose %a, [1, 0]
///    %bRow0 = vector.extract %b[0]
///    %atRow0 = vector.extract %at[0]
///    %c0 = vector.outerproduct %atRow0, %bRow0, %c
///    ...
///    %bRowK = vector.extract %b[K]
///    %atRowK = vector.extract %at[K]
///    %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
/// ```
///
/// This only kicks in when VectorTransformsOptions is set to OuterProduct and
/// the vector.contract op is a row-major matrix multiply.
class ContractionOpToOuterProductOpLowering
    : public OpRewritePattern<vector::ContractionOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  using FilterConstraintType =
      std::function<LogicalResult(vector::ContractionOp op)>;

  static LogicalResult defaultFilter(vector::ContractionOp op) {
    return success();
  }

  ContractionOpToOuterProductOpLowering(
      vector::VectorTransformsOptions vectorTransformOptions,
      MLIRContext *context, PatternBenefit benefit = 1,
      FilterConstraintType constraint = defaultFilter)
      : OpRewritePattern<vector::ContractionOp>(context, benefit),
        vectorTransformOptions(vectorTransformOptions),
        filter(std::move(constraint)) {}

  LogicalResult matchAndRewrite(vector::ContractionOp op,
                                PatternRewriter &rewriter) const override;

private:
  /// Options to control the vector patterns.
  vector::VectorTransformsOptions vectorTransformOptions;
  FilterConstraintType filter;
};

/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
/// semantics to an output-size-unrolled sequence:
/// ```
///    %out = arith.constant ... : vector<MxNxelt_type>
///    %bt = vector.transpose %b, [1, 0]
///    %aRow0 = vector.extract %a[0]
///    %btRow0 = vector.extract %bt[0]
///    %c00 = vector.reduce %aRow0, %btRow0
///    %out00 = vector.insert %c00, %out[0, 0]
///    ...
///    %aRowLast = vector.extract %a[M-1]
///    %btRowLast = vector.extract %bt[N-1]
///    %cLastLast = vector.reduce %aRowLast, %btRowLast
///    %outLastLast = vector.insert %cLastLast, %out[M-1, N-1]
/// ```
///
/// This only kicks in when VectorTransformsOptions is set to Dot and
/// the vector.contract op is a row-major matmul or matvec.
class ContractionOpToDotLowering
    : public OpRewritePattern<vector::ContractionOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  using FilterConstraintType =
      std::function<LogicalResult(vector::ContractionOp op)>;

  static LogicalResult defaultFilter(vector::ContractionOp op) {
    return success();
  }

  ContractionOpToDotLowering(
      vector::VectorTransformsOptions vectorTransformOptions,
      MLIRContext *context, PatternBenefit benefit = 1,
      const FilterConstraintType &constraint = defaultFilter)
      : OpRewritePattern<vector::ContractionOp>(context, benefit),
        vectorTransformOptions(vectorTransformOptions), filter(constraint) {}

  LogicalResult matchAndRewrite(vector::ContractionOp op,
                                PatternRewriter &rewriter) const override;

private:
  /// Options to control the vector patterns.
  vector::VectorTransformsOptions vectorTransformOptions;
  FilterConstraintType filter;
};

/// Progressive lowering of ContractionOp.
///
/// One:
///   %x = vector.contract with at least one free/batch dimension
/// is replaced by:
///   %a = vector.contract with one less free/batch dimension
///   %b = vector.contract with one less free/batch dimension
///   ..
///   %x = combine %a %b ..
/// until a pure contraction is reached (no free/batch dimensions),
/// which is replaced by a dot-product.
///
/// This only kicks in when either VectorTransformsOptions is set
/// to Dot or when other contraction patterns fail.
class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  using FilterConstraintType =
      std::function<LogicalResult(vector::ContractionOp op)>;

  static LogicalResult defaultFilter(vector::ContractionOp op) {
    return success();
  }

  ContractionOpLowering(vector::VectorTransformsOptions vectorTransformOptions,
                        MLIRContext *context, PatternBenefit benefit = 1,
                        FilterConstraintType constraint = defaultFilter)
      : OpRewritePattern<vector::ContractionOp>(context, benefit),
        vectorTransformOptions(vectorTransformOptions),
        filter(std::move(constraint)) {}

  LogicalResult matchAndRewrite(vector::ContractionOp op,
                                PatternRewriter &rewriter) const override;

private:
  /// Options to control the vector patterns.
  vector::VectorTransformsOptions vectorTransformOptions;
  FilterConstraintType filter;
  // Lower one parallel dimension.
  FailureOr<Value> lowerParallel(PatternRewriter &rewriter,
                                 vector::ContractionOp op, int64_t lhsIndex,
                                 int64_t rhsIndex, Value mask) const;
  // Lower one reduction dimension.
  FailureOr<Value> lowerReduction(PatternRewriter &rewriter,
                                  vector::ContractionOp op, Value mask) const;
};

} // namespace vector
} // namespace mlir
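The filter hook is what distinguishes these exposed pattern classes from the blanket populate functions. A sketch of registering one of them with a custom constraint; the lambda and its rank-2 restriction are made up for illustration, `ctx` is an assumed MLIRContext.

// Sketch: register ContractionOpLowering with a hypothetical filter that
// only fires on contractions whose LHS vector is 2-D.
RewritePatternSet patterns(ctx);
vector::VectorTransformsOptions opts;
auto onlyRank2Lhs = [](vector::ContractionOp op) {
  return success(op.getLhsType().getRank() == 2);
};
patterns.add<vector::ContractionOpLowering>(opts, ctx, /*benefit=*/1,
                                            onlyRank2Lhs);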
@@ -24,17 +24,53 @@ class IfOp;

namespace vector {

//===----------------------------------------------------------------------===//
// Vector transformation options exposed as auxiliary structs.
//===----------------------------------------------------------------------===//
/// Structure to control the behavior of vector transform patterns.
struct VectorTransformsOptions {
  /// Option to control the lowering of vector.contract.
  VectorContractLowering vectorContractLowering = VectorContractLowering::Dot;
  VectorTransformsOptions &
  setVectorTransformsOptions(VectorContractLowering opt) {
    vectorContractLowering = opt;
    return *this;
  }
  /// Option to control the lowering of vector.multi_reduction.
  VectorMultiReductionLowering vectorMultiReductionLowering =
      VectorMultiReductionLowering::InnerParallel;
  VectorTransformsOptions &
  setVectorMultiReductionLowering(VectorMultiReductionLowering opt) {
    vectorMultiReductionLowering = opt;
    return *this;
  }
  /// Option to control the lowering of vector.transpose.
  VectorTransposeLowering vectorTransposeLowering =
      VectorTransposeLowering::EltWise;
  VectorTransformsOptions &
  setVectorTransposeLowering(VectorTransposeLowering opt) {
    vectorTransposeLowering = opt;
    return *this;
  }
  /// Option to control the splitting of vector transfers.
  VectorTransferSplit vectorTransferSplit = VectorTransferSplit::None;
  VectorTransformsOptions &setVectorTransferSplit(VectorTransferSplit opt) {
    vectorTransferSplit = opt;
    return *this;
  }
};

//===----------------------------------------------------------------------===//
// Standalone transformations and helpers.
//===----------------------------------------------------------------------===//
/// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds
/// masking) fastpath and a slowpath. If `ifOp` is not null and the result is
/// `success`, the `ifOp` points to the newly created conditional upon function
/// return. To accommodate the fact that the original vector.transfer indexing
/// may be arbitrary and the slow path indexes @[0...0] in the temporary
/// buffer, the scf.if op returns a view and values of type index. At this
/// time, only the vector.transfer_read case is implemented.
///
/// Example (a 2-D vector.transfer_read):
/// ```
@@ -51,15 +87,16 @@ namespace vector {
///      memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4 : compatibleMemRefType, index, index
///    }
///    %0 = vector.transfer_read %1#0[%1#1, %1#2] {in_bounds = [true ... true]}
/// ```
/// where `alloc` is a top-of-the-function alloca'ed buffer of one vector.
///
/// Preconditions:
///  1. `xferOp.permutation_map()` must be a minor identity map
///  2. the rank of the `xferOp.memref()` and the rank of the `xferOp.vector()`
///  must be equal. This will be relaxed in the future but requires
///  rank-reducing subviews.
LogicalResult splitFullAndPartialTransfer(
    RewriterBase &b, VectorTransferOpInterface xferOp,
    VectorTransformsOptions options = VectorTransformsOptions(),
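The declaration above is cut off by the hunk. A hedged sketch of a call follows, assuming the remaining parameter is the `scf::IfOp *ifOp` out-parameter described in the comment; `rewriter` and `xferOp` are assumed to come from the surrounding pattern.

// Sketch only: the trailing `&ifOp` argument is an assumption based on the
// truncated declaration and its documentation.
scf::IfOp ifOp;
if (succeeded(vector::splitFullAndPartialTransfer(
        rewriter, xferOp, vector::VectorTransformsOptions(), &ifOp))) {
  // On success, `ifOp` points at the newly created conditional.
}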
@@ -16,6 +16,7 @@
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"

@@ -19,6 +19,7 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Dialect/X86Vector/Transforms.h"
#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
@@ -64,10 +65,11 @@ void LowerVectorToLLVMPass::runOnOperation() {
  RewritePatternSet patterns(&getContext());
  populateVectorToVectorCanonicalizationPatterns(patterns);
  populateVectorBroadcastLoweringPatterns(patterns);
  populateVectorContractLoweringPatterns(patterns);
  populateVectorContractLoweringPatterns(patterns, VectorTransformsOptions());
  populateVectorMaskOpLoweringPatterns(patterns);
  populateVectorShapeCastLoweringPatterns(patterns);
  populateVectorTransposeLoweringPatterns(patterns);
  populateVectorTransposeLoweringPatterns(patterns,
                                          VectorTransformsOptions());
  // Vector transfer ops with rank > 1 should be lowered with VectorToSCF.
  populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
  (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
@@ -10,8 +10,8 @@
//
//===----------------------------------------------------------------------===//

#include <optional>
#include <type_traits>

#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"

@@ -20,6 +20,7 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"

@@ -20,5 +20,5 @@ add_mlir_dialect_library(MLIRLinalgTransformOps
  MLIRSideEffectInterfaces
  MLIRTransformDialect
  MLIRTransformDialectUtils
  MLIRVectorDialect
  MLIRVectorTransforms
)

@@ -26,6 +26,7 @@
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
#include "mlir/Dialect/Transform/Utils/Utils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
@@ -7,13 +7,14 @@
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h"

#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
#include "mlir/Dialect/PDL/IR/PDL.h"
#include "mlir/Dialect/PDL/IR/PDLTypes.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Dialect/X86Vector/Transforms.h"
#include "mlir/Parser/Parser.h"
@@ -82,10 +83,9 @@ DiagnosedSilenceableFailure transform::LowerVectorsOp::apply(

  // In the future we may want to more finely select particular stages.
  // Stage 1: contraction lowerings.
  patterns.add<mlir::vector::ContractionOpToOuterProductOpLowering,
               mlir::vector::ContractionOpToMatmulOpLowering,
               mlir::vector::ContractionOpLowering>(vectorTransformOptions,
                                                    ctx);
  populateVectorContractLoweringPatterns(
      patterns, vectorTransformOptions, /*benefit=*/1,
      /*disableOuterProductLowering=*/true);
  vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);

  // Stage 2: multi-reduction lowerings.
@@ -93,8 +93,7 @@ DiagnosedSilenceableFailure transform::LowerVectorsOp::apply(
      patterns, vectorTransformOptions.vectorMultiReductionLowering);

  // Stage 3: Rewrite vector.transfer into full and partial parts.
  patterns.add<vector::VectorTransferFullPartialRewriter>(
      ctx, vectorTransformOptions);
  populateVectorTransferFullPartialPatterns(patterns, vectorTransformOptions);

  // Stage 4: Lower vector transfers.
  vector::populateVectorTransferLoweringPatterns(patterns, maxTransferRank);
@@ -107,8 +106,8 @@ DiagnosedSilenceableFailure transform::LowerVectorsOp::apply(
  vector::populateVectorShapeCastLoweringPatterns(patterns);

  // Stage 7: Lower vector.transpose.
  vector::populateVectorTransposeLoweringPatterns(patterns,
                                                  vectorTransformOptions);
  vector::populateVectorTransposeLoweringPatterns(
      patterns, vectorTransformOptions, /*benefit=*/1);
  if (getTransposeAvx2Lowering())
    x86vector::avx2::populateSpecializedTransposeLoweringPatterns(
        patterns, avx2LoweringOptions, /*benefit=*/10);
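Pulling the hunks together, the staged lowering in LowerVectorsOp::apply now reads approximately as follows; stage 5 lies outside this diff and is omitted, and stages 4 and 6 are reproduced from the context lines above.

// Sketch of the stage sequence after this commit.
// Stage 1: contraction lowerings (outer-product lowering handled separately).
populateVectorContractLoweringPatterns(patterns, vectorTransformOptions,
                                       /*benefit=*/1,
                                       /*disableOuterProductLowering=*/true);
vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
// Stage 2: multi-reduction lowerings.
vector::populateVectorMultiReductionLoweringPatterns(
    patterns, vectorTransformOptions.vectorMultiReductionLowering);
// Stage 3: rewrite vector.transfer into full and partial parts.
populateVectorTransferFullPartialPatterns(patterns, vectorTransformOptions);
// Stage 4: lower vector transfers.
vector::populateVectorTransferLoweringPatterns(patterns, maxTransferRank);
// Stage 6: lower vector.shape_cast.
vector::populateVectorShapeCastLoweringPatterns(patterns);
// Stage 7: lower vector.transpose.
vector::populateVectorTransposeLoweringPatterns(
    patterns, vectorTransformOptions, /*benefit=*/1);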
@@ -1,14 +1,20 @@
add_mlir_dialect_library(MLIRVectorTransforms
  BufferizableOpInterfaceImpl.cpp
  Bufferize.cpp
  LowerVectorBroadcast.cpp
  LowerVectorContract.cpp
  LowerVectorGather.cpp
  LowerVectorMask.cpp
  LowerVectorMultiReduction.cpp
  LowerVectorScan.cpp
  LowerVectorShapeCast.cpp
  LowerVectorTransfer.cpp
  LowerVectorTranspose.cpp
  VectorDistribute.cpp
  VectorDropLeadUnitDim.cpp
  VectorInsertExtractStridedSliceRewritePatterns.cpp
  VectorMultiDimReductionTransforms.cpp
  VectorTransferOpTransforms.cpp
  VectorTransferSplitRewritePatterns.cpp
  VectorTransferPermutationMapRewritePatterns.cpp
  VectorTransforms.cpp
  VectorUnroll.cpp
mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp (new file, 156 lines)
@@ -0,0 +1,156 @@
//===- LowerVectorBroadcast.cpp - Lower 'vector.broadcast' operation ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent rewrites and utilities to lower the
// 'vector.broadcast' operation.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/VectorInterfaces.h"
#include "mlir/Support/LogicalResult.h"

#define DEBUG_TYPE "vector-broadcast-lowering"

using namespace mlir;
using namespace mlir::vector;

namespace {
/// Progressive lowering of BroadcastOp.
class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::BroadcastOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    VectorType dstType = op.getResultVectorType();
    VectorType srcType = op.getSourceType().dyn_cast<VectorType>();
    Type eltType = dstType.getElementType();

    // Scalar to any vector can use splat.
    if (!srcType) {
      rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, op.getSource());
      return success();
    }

    // Determine rank of source and destination.
    int64_t srcRank = srcType.getRank();
    int64_t dstRank = dstType.getRank();

    // Stretching scalar inside vector (e.g. vector<1xf32>) can use splat.
    if (srcRank <= 1 && dstRank == 1) {
      Value ext;
      if (srcRank == 0)
        ext = rewriter.create<vector::ExtractElementOp>(loc, op.getSource());
      else
        ext = rewriter.create<vector::ExtractOp>(loc, op.getSource(), 0);
      rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, ext);
      return success();
    }

    // Duplicate this rank.
    // For example:
    //   %x = broadcast %y  : k-D to n-D, k < n
    // becomes:
    //   %b = broadcast %y  : k-D to (n-1)-D
    //   %x = [%b,%b,%b,%b] : n-D
    // becomes:
    //   %b = [%y,%y]       : (n-1)-D
    //   %x = [%b,%b,%b,%b] : n-D
    if (srcRank < dstRank) {
      // Duplication.
      VectorType resType =
          VectorType::get(dstType.getShape().drop_front(), eltType);
      Value bcst =
          rewriter.create<vector::BroadcastOp>(loc, resType, op.getSource());
      Value result = rewriter.create<arith::ConstantOp>(
          loc, dstType, rewriter.getZeroAttr(dstType));
      for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d)
        result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
      rewriter.replaceOp(op, result);
      return success();
    }

    // Find non-matching dimension, if any.
    assert(srcRank == dstRank);
    int64_t m = -1;
    for (int64_t r = 0; r < dstRank; r++)
      if (srcType.getDimSize(r) != dstType.getDimSize(r)) {
        m = r;
        break;
      }

    // All trailing dimensions are the same. Simply pass through.
    if (m == -1) {
      rewriter.replaceOp(op, op.getSource());
      return success();
    }

    // Any non-matching dimension forces a stretch along this rank.
    // For example:
    //   %x = broadcast %y : vector<4x1x2xf32> to vector<4x2x2xf32>
    // becomes:
    //   %a = broadcast %y[0] : vector<1x2xf32> to vector<2x2xf32>
    //   %b = broadcast %y[1] : vector<1x2xf32> to vector<2x2xf32>
    //   %c = broadcast %y[2] : vector<1x2xf32> to vector<2x2xf32>
    //   %d = broadcast %y[3] : vector<1x2xf32> to vector<2x2xf32>
    //   %x = [%a,%b,%c,%d]
    // becomes:
    //   %u = broadcast %y[0][0] : vector<2xf32> to vector<2x2xf32>
    //   %v = broadcast %y[1][0] : vector<2xf32> to vector<2x2xf32>
    //   %a = [%u, %v]
    //   ..
    //   %x = [%a,%b,%c,%d]
    VectorType resType =
        VectorType::get(dstType.getShape().drop_front(), eltType);
    Value result = rewriter.create<arith::ConstantOp>(
        loc, dstType, rewriter.getZeroAttr(dstType));
    if (m == 0) {
      // Stretch at start.
      Value ext = rewriter.create<vector::ExtractOp>(loc, op.getSource(), 0);
      Value bcst = rewriter.create<vector::BroadcastOp>(loc, resType, ext);
      for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d)
        result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
    } else {
      // Stretch not at start.
      for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d) {
        Value ext = rewriter.create<vector::ExtractOp>(loc, op.getSource(), d);
        Value bcst = rewriter.create<vector::BroadcastOp>(loc, resType, ext);
        result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
      }
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};
} // namespace

void mlir::vector::populateVectorBroadcastLoweringPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<BroadcastOpLowering>(patterns.getContext(), benefit);
}
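The pattern is self-recursive through the lower-rank `vector.broadcast` it emits, so a greedy driver keeps matching until only splats and inserts remain. A minimal driver sketch; `module` is an assumed ModuleOp handle.

// Sketch: apply the broadcast lowering exhaustively over a module.
RewritePatternSet patterns(module->getContext());
vector::populateVectorBroadcastLoweringPatterns(patterns);
(void)applyPatternsAndFoldGreedily(module, std::move(patterns));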
mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp (new file, 1329 lines; diff suppressed because it is too large)
mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp (new file, 173 lines)
@@ -0,0 +1,173 @@
|
||||
//===- LowerVectorScam.cpp - Lower 'vector.scan' operation ----------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file implements target-independent rewrites and utilities to lower the
|
||||
// 'vector.scan' operation.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Arith/Utils/Utils.h"
|
||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Dialect/Utils/IndexingUtils.h"
|
||||
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
|
||||
#include "mlir/Dialect/Vector/IR/VectorOps.h"
|
||||
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
|
||||
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
|
||||
#include "mlir/IR/BuiltinAttributeInterfaces.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/ImplicitLocOpBuilder.h"
|
||||
#include "mlir/IR/Location.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
#include "mlir/Interfaces/VectorInterfaces.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
|
||||
#define DEBUG_TYPE "vector-broadcast-lowering"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::vector;
|
||||
|
||||
namespace {
|
||||
/// Flattens 2 or more dimensional `vector.gather` ops by unrolling the
|
||||
/// outermost dimension. For example:
|
||||
/// ```
|
||||
/// %g = vector.gather %base[%c0][%v], %mask, %pass_thru :
|
||||
/// ... into vector<2x3xf32>
|
||||
///
|
||||
/// ==>
|
||||
///
|
||||
/// %0 = arith.constant dense<0.0> : vector<2x3xf32>
|
||||
/// %g0 = vector.gather %base[%c0][%v0], %mask0, %pass_thru0 : ...
|
||||
/// %1 = vector.insert %g0, %0 [0] : vector<3xf32> into vector<2x3xf32>
|
||||
/// %g1 = vector.gather %base[%c0][%v1], %mask1, %pass_thru1 : ...
|
||||
/// %g = vector.insert %g1, %1 [1] : vector<3xf32> into vector<2x3xf32>
|
||||
/// ```
|
||||
///
|
||||
/// When applied exhaustively, this will produce a sequence of 1-d gather ops.
|
||||
struct FlattenGather : OpRewritePattern<vector::GatherOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::GatherOp op,
                                PatternRewriter &rewriter) const override {
    VectorType resultTy = op.getType();
    if (resultTy.getRank() < 2)
      return rewriter.notifyMatchFailure(op, "already flat");

    Location loc = op.getLoc();
    Value indexVec = op.getIndexVec();
    Value maskVec = op.getMask();
    Value passThruVec = op.getPassThru();

    Value result = rewriter.create<arith::ConstantOp>(
        loc, resultTy, rewriter.getZeroAttr(resultTy));

    Type subTy = VectorType::get(resultTy.getShape().drop_front(),
                                 resultTy.getElementType());

    for (int64_t i = 0, e = resultTy.getShape().front(); i < e; ++i) {
      int64_t thisIdx[1] = {i};

      Value indexSubVec =
          rewriter.create<vector::ExtractOp>(loc, indexVec, thisIdx);
      Value maskSubVec =
          rewriter.create<vector::ExtractOp>(loc, maskVec, thisIdx);
      Value passThruSubVec =
          rewriter.create<vector::ExtractOp>(loc, passThruVec, thisIdx);
      Value subGather = rewriter.create<vector::GatherOp>(
          loc, subTy, op.getBase(), op.getIndices(), indexSubVec, maskSubVec,
          passThruSubVec);
      result =
          rewriter.create<vector::InsertOp>(loc, subGather, result, thisIdx);
    }

    rewriter.replaceOp(op, result);
    return success();
  }
};

/// Turns 1-d `vector.gather` into a scalarized sequence of `vector.load`s or
/// `tensor.extract`s. To avoid out-of-bounds memory accesses, these
/// loads/extracts are made conditional using `scf.if` ops.
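///
/// For example (an illustrative sketch; SSA names are schematic, not verbatim
/// compiler output), one lane of a gather from a memref becomes roughly:
///
///   %c = vector.extract %mask[0] : vector<2xi1>
///   %r = scf.if %c -> (vector<2xf32>) {
///     %ld = vector.load %base[%idx] : memref<?xf32>, vector<1xf32>
///     %e = vector.extract %ld[0] : vector<1xf32>
///     %v = vector.insert %e, %pass_thru [0] : f32 into vector<2xf32>
///     scf.yield %v : vector<2xf32>
///   } else {
///     scf.yield %pass_thru : vector<2xf32>
///   }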
struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::GatherOp op,
                                PatternRewriter &rewriter) const override {
    VectorType resultTy = op.getType();
    if (resultTy.getRank() != 1)
      return rewriter.notifyMatchFailure(op, "unsupported rank");

    Location loc = op.getLoc();
    Type elemTy = resultTy.getElementType();
    // Vector type with a single element. Used to generate `vector.load`s.
    VectorType elemVecTy = VectorType::get({1}, elemTy);

    Value condMask = op.getMask();
    Value base = op.getBase();
    Value indexVec = rewriter.createOrFold<arith::IndexCastOp>(
        loc, op.getIndexVectorType().clone(rewriter.getIndexType()),
        op.getIndexVec());
    auto baseOffsets = llvm::to_vector(op.getIndices());
    Value lastBaseOffset = baseOffsets.back();

    Value result = op.getPassThru();

    // Emit a conditional access for each vector element.
    for (int64_t i = 0, e = resultTy.getNumElements(); i < e; ++i) {
      int64_t thisIdx[1] = {i};
      Value condition =
          rewriter.create<vector::ExtractOp>(loc, condMask, thisIdx);
      Value index = rewriter.create<vector::ExtractOp>(loc, indexVec, thisIdx);
      baseOffsets.back() =
          rewriter.createOrFold<arith::AddIOp>(loc, lastBaseOffset, index);

      auto loadBuilder = [&](OpBuilder &b, Location loc) {
        Value extracted;
        if (isa<MemRefType>(base.getType())) {
          // `vector.load` does not support scalar results; emit a vector load
          // and extract the single result instead.
          Value load =
              b.create<vector::LoadOp>(loc, elemVecTy, base, baseOffsets);
          int64_t zeroIdx[1] = {0};
          extracted = b.create<vector::ExtractOp>(loc, load, zeroIdx);
        } else {
          extracted = b.create<tensor::ExtractOp>(loc, base, baseOffsets);
        }

        Value newResult =
            b.create<vector::InsertOp>(loc, extracted, result, thisIdx);
        b.create<scf::YieldOp>(loc, newResult);
      };
      auto passThruBuilder = [result](OpBuilder &b, Location loc) {
        b.create<scf::YieldOp>(loc, result);
      };

      result =
          rewriter
              .create<scf::IfOp>(loc, condition, /*thenBuilder=*/loadBuilder,
                                 /*elseBuilder=*/passThruBuilder)
              .getResult(0);
    }

    rewriter.replaceOp(op, result);
    return success();
  }
};
} // namespace

void mlir::vector::populateVectorGatherLoweringPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<FlattenGather, Gather1DToConditionalLoads>(patterns.getContext(),
                                                          benefit);
}
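
// A minimal usage sketch (illustrative; `funcOp` is a hypothetical op being
// processed, and a greedy driver is only one way to apply these patterns):
//
//   RewritePatternSet patterns(funcOp.getContext());
//   vector::populateVectorGatherLoweringPatterns(patterns);
//   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));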
@@ -6,7 +6,7 @@
 //
 //===----------------------------------------------------------------------===//
 //
-// This file implements target-independent rewrites and utilitites to lower the
+// This file implements target-independent rewrites and utilities to lower the
 // 'vector.mask' operation.
 //
 //===----------------------------------------------------------------------===//
@@ -14,6 +14,7 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
 #include "mlir/Dialect/Vector/Transforms/Passes.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -30,6 +31,147 @@ namespace vector {
 using namespace mlir;
 using namespace mlir::vector;

//===----------------------------------------------------------------------===//
// populateVectorMaskOpLoweringPatterns
//===----------------------------------------------------------------------===//

namespace {
/// Progressive lowering of CreateMaskOp.
/// One:
///   %x = vector.create_mask %a, ... : vector<dx...>
/// is replaced by:
///   %l = vector.create_mask ... : vector<...>  ;  one lower rank
///   %0 = arith.cmpi "slt", %ci, %a       |
///   %1 = select %0, %l, %zeroes          |
///   %r = vector.insert %1, %pr [i]       | d-times
///   %x = ....
/// until a one-dimensional vector is reached.
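///
/// For example (illustrative), %m = vector.create_mask %a, %b : vector<4x8xi1>
/// unrolls, for each outer index i in [0, 4), into:
///   %l = vector.create_mask %b : vector<8xi1>
///   %c = arith.cmpi slt, %ci, %a : index   ; %ci = arith.constant i : index
///   %s = arith.select %c, %l, %zeroes : vector<8xi1>
///   %m_i = vector.insert %s, %m_prev [i] : vector<8xi1> into vector<4x8xi1>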
class CreateMaskOpLowering : public OpRewritePattern<vector::CreateMaskOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::CreateMaskOp op,
                                PatternRewriter &rewriter) const override {
    auto dstType = op.getResult().getType().cast<VectorType>();
    int64_t rank = dstType.getRank();
    if (rank <= 1)
      return rewriter.notifyMatchFailure(
          op, "0-D and 1-D vectors are handled separately");

    auto loc = op.getLoc();
    auto eltType = dstType.getElementType();
    int64_t dim = dstType.getDimSize(0);
    Value idx = op.getOperand(0);

    VectorType lowType =
        VectorType::get(dstType.getShape().drop_front(), eltType);
    Value trueVal = rewriter.create<vector::CreateMaskOp>(
        loc, lowType, op.getOperands().drop_front());
    Value falseVal = rewriter.create<arith::ConstantOp>(
        loc, lowType, rewriter.getZeroAttr(lowType));
    Value result = rewriter.create<arith::ConstantOp>(
        loc, dstType, rewriter.getZeroAttr(dstType));
    for (int64_t d = 0; d < dim; d++) {
      Value bnd =
          rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(d));
      Value val = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
                                                 bnd, idx);
      Value sel = rewriter.create<arith::SelectOp>(loc, val, trueVal, falseVal);
      auto pos = rewriter.getI64ArrayAttr(d);
      result =
          rewriter.create<vector::InsertOp>(loc, dstType, sel, result, pos);
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};

/// Progressive lowering of ConstantMaskOp.
/// One:
///   %x = vector.constant_mask [a,b]
/// is replaced by:
///   %z = zero-result
///   %l = vector.constant_mask [b]
///   %4 = vector.insert %l, %z[0]
///   ..
///   %x = vector.insert %l, %..[a-1]
/// until a one-dimensional vector is reached. All these operations
/// will be folded at LLVM IR level.
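///
/// For example (illustrative):
///   %x = vector.constant_mask [2, 3] : vector<4x8xi1>
/// becomes two insertions of %l = vector.constant_mask [3] : vector<8xi1>
/// at positions [0] and [1] of an all-false vector<4x8xi1>.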
class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ConstantMaskOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto dstType = op.getType();
    auto eltType = dstType.getElementType();
    auto dimSizes = op.getMaskDimSizes();
    int64_t rank = dstType.getRank();

    if (rank == 0) {
      assert(dimSizes.size() == 1 &&
             "Expected exactly one dim size for a 0-D vector");
      bool value = dimSizes[0].cast<IntegerAttr>().getInt() == 1;
      rewriter.replaceOpWithNewOp<arith::ConstantOp>(
          op, dstType,
          DenseIntElementsAttr::get(
              VectorType::get(ArrayRef<int64_t>{}, rewriter.getI1Type()),
              ArrayRef<bool>{value}));
      return success();
    }

    // Scalable constant masks can only be lowered for the "none set" case.
    if (dstType.cast<VectorType>().isScalable()) {
      rewriter.replaceOpWithNewOp<arith::ConstantOp>(
          op, DenseElementsAttr::get(dstType, false));
      return success();
    }

    int64_t trueDim = std::min(dstType.getDimSize(0),
                               dimSizes[0].cast<IntegerAttr>().getInt());

    if (rank == 1) {
      // Express constant 1-D case in explicit vector form:
      //   [T,..,T,F,..,F].
      SmallVector<bool> values(dstType.getDimSize(0));
      for (int64_t d = 0; d < trueDim; d++)
        values[d] = true;
      rewriter.replaceOpWithNewOp<arith::ConstantOp>(
          op, dstType, rewriter.getBoolVectorAttr(values));
      return success();
    }

    VectorType lowType =
        VectorType::get(dstType.getShape().drop_front(), eltType);
    SmallVector<int64_t> newDimSizes;
    for (int64_t r = 1; r < rank; r++)
      newDimSizes.push_back(dimSizes[r].cast<IntegerAttr>().getInt());
    Value trueVal = rewriter.create<vector::ConstantMaskOp>(
        loc, lowType, rewriter.getI64ArrayAttr(newDimSizes));
    Value result = rewriter.create<arith::ConstantOp>(
        loc, dstType, rewriter.getZeroAttr(dstType));
    for (int64_t d = 0; d < trueDim; d++) {
      auto pos = rewriter.getI64ArrayAttr(d);
      result =
          rewriter.create<vector::InsertOp>(loc, dstType, trueVal, result, pos);
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};
} // namespace

void mlir::vector::populateVectorMaskOpLoweringPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<CreateMaskOpLowering, ConstantMaskOpLowering>(
      patterns.getContext(), benefit);
}

//===----------------------------------------------------------------------===//
// populateVectorMaskLoweringPatternsForSideEffectingOps
//===----------------------------------------------------------------------===//

namespace {

/// The `MaskOpRewritePattern` implements a pattern that follows a two-fold
@@ -1,4 +1,4 @@
-//===- VectorMultiDimReductionTransforms.cpp - Multi-Reduction Transforms -===//
+//===- LowerVectorMultiReduction.cpp - Lower `vector.multi_reduction` op --===//
 //
 /// Part of the LLVM Project, under the Apache License v2.0 with LLVM
 /// Exceptions. See https://llvm.org/LICENSE.txt for license information.
@@ -6,12 +6,13 @@
 //
 //===----------------------------------------------------------------------===//
 //
-/// This file implements target-independent rewrites of MultiDimReductionOp.
+// This file implements target-independent rewrites and utilities to lower the
+// 'vector.multi_reduction' operation.
 //
 //===----------------------------------------------------------------------===//

 #include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/TypeUtilities.h"

@@ -19,6 +20,7 @@

 using namespace mlir;

 namespace {
 /// This file implements the following transformations as composable atomic
 /// patterns.

@@ -441,6 +443,7 @@ struct OneDimMultiReductionToTwoDim
     return success();
   }
 };
 } // namespace

 void mlir::vector::populateVectorMultiReductionLoweringPatterns(
     RewritePatternSet &patterns, VectorMultiReductionLowering options,
251 mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp Normal file
@@ -0,0 +1,251 @@
//===- LowerVectorScan.cpp - Lower 'vector.scan' operation ---------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent rewrites and utilities to lower the
// 'vector.scan' operation.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/VectorInterfaces.h"
#include "mlir/Support/LogicalResult.h"

#define DEBUG_TYPE "vector-scan-lowering"

using namespace mlir;
using namespace mlir::vector;

/// This function constructs the appropriate integer or float
/// operation given the vector combining kind and operands. The
/// supported integer operations are: add, mul, min (signed/unsigned),
/// max (signed/unsigned), and, or, xor. The supported float
/// operations are: add, mul, min, and max.
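///
/// For example (illustrative), genOperator(loc, %x, %y, CombiningKind::ADD,
/// rewriter) on signless-integer vector operands materializes:
///   %r = arith.addi %x, %y : vector<2x1xi32>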
static Value genOperator(Location loc, Value x, Value y,
                         vector::CombiningKind kind,
                         PatternRewriter &rewriter) {
  using vector::CombiningKind;

  auto elType = x.getType().cast<VectorType>().getElementType();
  bool isInt = elType.isIntOrIndex();

  Value combinedResult{nullptr};
  switch (kind) {
  case CombiningKind::ADD:
    if (isInt)
      combinedResult = rewriter.create<arith::AddIOp>(loc, x, y);
    else
      combinedResult = rewriter.create<arith::AddFOp>(loc, x, y);
    break;
  case CombiningKind::MUL:
    if (isInt)
      combinedResult = rewriter.create<arith::MulIOp>(loc, x, y);
    else
      combinedResult = rewriter.create<arith::MulFOp>(loc, x, y);
    break;
  case CombiningKind::MINUI:
    combinedResult = rewriter.create<arith::MinUIOp>(loc, x, y);
    break;
  case CombiningKind::MINSI:
    combinedResult = rewriter.create<arith::MinSIOp>(loc, x, y);
    break;
  case CombiningKind::MAXUI:
    combinedResult = rewriter.create<arith::MaxUIOp>(loc, x, y);
    break;
  case CombiningKind::MAXSI:
    combinedResult = rewriter.create<arith::MaxSIOp>(loc, x, y);
    break;
  case CombiningKind::AND:
    combinedResult = rewriter.create<arith::AndIOp>(loc, x, y);
    break;
  case CombiningKind::OR:
    combinedResult = rewriter.create<arith::OrIOp>(loc, x, y);
    break;
  case CombiningKind::XOR:
    combinedResult = rewriter.create<arith::XOrIOp>(loc, x, y);
    break;
  case CombiningKind::MINF:
    combinedResult = rewriter.create<arith::MinFOp>(loc, x, y);
    break;
  case CombiningKind::MAXF:
    combinedResult = rewriter.create<arith::MaxFOp>(loc, x, y);
    break;
  }
  return combinedResult;
}

/// This function checks whether the vector combining kind is consistent with
/// the integer or float element type.
static bool isValidKind(bool isInt, vector::CombiningKind kind) {
  using vector::CombiningKind;
  enum class KindType { FLOAT, INT, INVALID };
  KindType type{KindType::INVALID};
  switch (kind) {
  case CombiningKind::MINF:
  case CombiningKind::MAXF:
    type = KindType::FLOAT;
    break;
  case CombiningKind::MINUI:
  case CombiningKind::MINSI:
  case CombiningKind::MAXUI:
  case CombiningKind::MAXSI:
  case CombiningKind::AND:
  case CombiningKind::OR:
  case CombiningKind::XOR:
    type = KindType::INT;
    break;
  case CombiningKind::ADD:
  case CombiningKind::MUL:
    type = isInt ? KindType::INT : KindType::FLOAT;
    break;
  }
  bool isValidIntKind = (type == KindType::INT) && isInt;
  bool isValidFloatKind = (type == KindType::FLOAT) && (!isInt);
  return (isValidIntKind || isValidFloatKind);
}

namespace {
/// Convert vector.scan op into arith ops and vector.insert_strided_slice /
/// vector.extract_strided_slice.
///
/// Example:
///
/// ```
///   %0:2 = vector.scan <mul>, %arg0, %arg1
///     {inclusive = true, reduction_dim = 1} :
///     (vector<2x3xi32>, vector<2xi32>) to (vector<2x3xi32>, vector<2xi32>)
/// ```
///
/// is converted to:
///
/// ```
///   %cst = arith.constant dense<0> : vector<2x3xi32>
///   %0 = vector.extract_strided_slice %arg0
///     {offsets = [0, 0], sizes = [2, 1], strides = [1, 1]}
///       : vector<2x3xi32> to vector<2x1xi32>
///   %1 = vector.insert_strided_slice %0, %cst
///     {offsets = [0, 0], strides = [1, 1]}
///       : vector<2x1xi32> into vector<2x3xi32>
///   %2 = vector.extract_strided_slice %arg0
///     {offsets = [0, 1], sizes = [2, 1], strides = [1, 1]}
///       : vector<2x3xi32> to vector<2x1xi32>
///   %3 = arith.muli %0, %2 : vector<2x1xi32>
///   %4 = vector.insert_strided_slice %3, %1
///     {offsets = [0, 1], strides = [1, 1]}
///       : vector<2x1xi32> into vector<2x3xi32>
///   %5 = vector.extract_strided_slice %arg0
///     {offsets = [0, 2], sizes = [2, 1], strides = [1, 1]}
///       : vector<2x3xi32> to vector<2x1xi32>
///   %6 = arith.muli %3, %5 : vector<2x1xi32>
///   %7 = vector.insert_strided_slice %6, %4
///     {offsets = [0, 2], strides = [1, 1]}
///       : vector<2x1xi32> into vector<2x3xi32>
///   %8 = vector.shape_cast %6 : vector<2x1xi32> to vector<2xi32>
///   return %7, %8 : vector<2x3xi32>, vector<2xi32>
/// ```
struct ScanToArithOps : public OpRewritePattern<vector::ScanOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ScanOp scanOp,
                                PatternRewriter &rewriter) const override {
    auto loc = scanOp.getLoc();
    VectorType destType = scanOp.getDestType();
    ArrayRef<int64_t> destShape = destType.getShape();
    auto elType = destType.getElementType();
    bool isInt = elType.isIntOrIndex();
    if (!isValidKind(isInt, scanOp.getKind()))
      return failure();

    VectorType resType = VectorType::get(destShape, elType);
    Value result = rewriter.create<arith::ConstantOp>(
        loc, resType, rewriter.getZeroAttr(resType));
    int64_t reductionDim = scanOp.getReductionDim();
    bool inclusive = scanOp.getInclusive();
    int64_t destRank = destType.getRank();
    VectorType initialValueType = scanOp.getInitialValueType();
    int64_t initialValueRank = initialValueType.getRank();

    SmallVector<int64_t> reductionShape(destShape.begin(), destShape.end());
    reductionShape[reductionDim] = 1;
    VectorType reductionType = VectorType::get(reductionShape, elType);
    SmallVector<int64_t> offsets(destRank, 0);
    SmallVector<int64_t> strides(destRank, 1);
    SmallVector<int64_t> sizes(destShape.begin(), destShape.end());
    sizes[reductionDim] = 1;
    ArrayAttr scanSizes = rewriter.getI64ArrayAttr(sizes);
    ArrayAttr scanStrides = rewriter.getI64ArrayAttr(strides);

    Value lastOutput, lastInput;
    for (int i = 0; i < destShape[reductionDim]; i++) {
      offsets[reductionDim] = i;
      ArrayAttr scanOffsets = rewriter.getI64ArrayAttr(offsets);
      Value input = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, reductionType, scanOp.getSource(), scanOffsets, scanSizes,
          scanStrides);
      Value output;
      if (i == 0) {
        if (inclusive) {
          output = input;
        } else {
          if (initialValueRank == 0) {
            // ShapeCastOp cannot handle 0-D vectors.
            output = rewriter.create<vector::BroadcastOp>(
                loc, input.getType(), scanOp.getInitialValue());
          } else {
            output = rewriter.create<vector::ShapeCastOp>(
                loc, input.getType(), scanOp.getInitialValue());
          }
        }
      } else {
        Value y = inclusive ? input : lastInput;
        output = genOperator(loc, lastOutput, y, scanOp.getKind(), rewriter);
        assert(output != nullptr);
      }
      result = rewriter.create<vector::InsertStridedSliceOp>(
          loc, output, result, offsets, strides);
      lastOutput = output;
      lastInput = input;
    }

    Value reduction;
    if (initialValueRank == 0) {
      Value v = rewriter.create<vector::ExtractOp>(loc, lastOutput, 0);
      reduction =
          rewriter.create<vector::BroadcastOp>(loc, initialValueType, v);
    } else {
      reduction = rewriter.create<vector::ShapeCastOp>(loc, initialValueType,
                                                       lastOutput);
    }

    rewriter.replaceOp(scanOp, {result, reduction});
    return success();
  }
};
} // namespace

void mlir::vector::populateVectorScanLoweringPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<ScanToArithOps>(patterns.getContext(), benefit);
}

177 mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp Normal file
@@ -0,0 +1,177 @@
//===- LowerVectorShapeCast.cpp - Lower 'vector.shape_cast' operation -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent rewrites and utilities to lower the
// 'vector.shape_cast' operation.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/VectorInterfaces.h"
#include "mlir/Support/LogicalResult.h"

#define DEBUG_TYPE "vector-shape-cast-lowering"

using namespace mlir;
using namespace mlir::vector;

namespace {
/// ShapeOp 2D -> 1D downcast serves the purpose of flattening 2-D to 1-D
/// vectors progressively on the way to target llvm.matrix intrinsics.
/// This iterates over the most major dimension of the 2-D vector and performs
/// rewrites into:
///   vector.extract from 2-D + vector.insert_strided_slice offset into 1-D
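///
/// For example (illustrative):
///   vector.shape_cast %v : vector<2x4xf32> to vector<8xf32>
/// becomes two vector.extract + vector.insert_strided_slice pairs writing at
/// offsets 0 and 4 of a zero-initialized vector<8xf32>.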
class ShapeCastOp2DDownCastRewritePattern
    : public OpRewritePattern<vector::ShapeCastOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
                                PatternRewriter &rewriter) const override {
    auto sourceVectorType = op.getSourceVectorType();
    auto resultVectorType = op.getResultVectorType();
    if (sourceVectorType.getRank() != 2 || resultVectorType.getRank() != 1)
      return failure();

    auto loc = op.getLoc();
    Value desc = rewriter.create<arith::ConstantOp>(
        loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
    unsigned mostMinorVectorSize = sourceVectorType.getShape()[1];
    for (int64_t i = 0, e = sourceVectorType.getShape().front(); i != e; ++i) {
      Value vec = rewriter.create<vector::ExtractOp>(loc, op.getSource(), i);
      desc = rewriter.create<vector::InsertStridedSliceOp>(
          loc, vec, desc,
          /*offsets=*/i * mostMinorVectorSize, /*strides=*/1);
    }
    rewriter.replaceOp(op, desc);
    return success();
  }
};

/// ShapeOp 1D -> 2D upcast serves the purpose of unflattening 2-D from 1-D
/// vectors progressively.
/// This iterates over the most major dimension of the 2-D vector and performs
/// rewrites into:
///   vector.extract_strided_slice from 1-D + vector.insert into 2-D
/// Note that 1-D extract_strided_slice ops are lowered to efficient
/// vector.shuffle ops.
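///
/// For example (illustrative):
///   vector.shape_cast %v : vector<8xf32> to vector<2x4xf32>
/// extracts 1-D slices of size 4 at offsets 0 and 4 and inserts them at rows
/// [0] and [1] of a zero-initialized vector<2x4xf32>.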
class ShapeCastOp2DUpCastRewritePattern
    : public OpRewritePattern<vector::ShapeCastOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
                                PatternRewriter &rewriter) const override {
    auto sourceVectorType = op.getSourceVectorType();
    auto resultVectorType = op.getResultVectorType();
    if (sourceVectorType.getRank() != 1 || resultVectorType.getRank() != 2)
      return failure();

    auto loc = op.getLoc();
    Value desc = rewriter.create<arith::ConstantOp>(
        loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
    unsigned mostMinorVectorSize = resultVectorType.getShape()[1];
    for (int64_t i = 0, e = resultVectorType.getShape().front(); i != e; ++i) {
      Value vec = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, op.getSource(), /*offsets=*/i * mostMinorVectorSize,
          /*sizes=*/mostMinorVectorSize,
          /*strides=*/1);
      desc = rewriter.create<vector::InsertOp>(loc, vec, desc, i);
    }
    rewriter.replaceOp(op, desc);
    return success();
  }
};

// We typically should not lower general shape cast operations into data
// movement instructions, since the assumption is that these casts are
// optimized away during progressive lowering. For completeness, however,
// we fall back to a reference implementation that moves all elements
// into the right place if we get here.
class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto sourceVectorType = op.getSourceVectorType();
    auto resultVectorType = op.getResultVectorType();

    // Special case 2D / 1D lowerings with better implementations.
    // TODO: make it ND / 1D to allow generic ND -> 1D -> MD.
    int64_t srcRank = sourceVectorType.getRank();
    int64_t resRank = resultVectorType.getRank();
    if ((srcRank == 2 && resRank == 1) || (srcRank == 1 && resRank == 2))
      return failure();

    // The generic ShapeCast lowering path goes all the way down to unrolled
    // scalar extract/insert chains.
    // TODO: consider evolving the semantics to only allow 1D source or dest
    // and drop this potentially very expensive lowering.
    // Compute number of elements involved in the reshape.
    int64_t numElts = 1;
    for (int64_t r = 0; r < srcRank; r++)
      numElts *= sourceVectorType.getDimSize(r);
    // Replace with data movement operations:
    //   x[0,0,0] = y[0,0]
    //   x[0,0,1] = y[0,1]
    //   x[0,1,0] = y[0,2]
    // etc., incrementing the two index vectors "row-major"
    // within the source and result shape.
    SmallVector<int64_t> srcIdx(srcRank);
    SmallVector<int64_t> resIdx(resRank);
    Value result = rewriter.create<arith::ConstantOp>(
        loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
    for (int64_t i = 0; i < numElts; i++) {
      if (i != 0) {
        incIdx(srcIdx, sourceVectorType, srcRank - 1);
        incIdx(resIdx, resultVectorType, resRank - 1);
      }
      Value e = rewriter.create<vector::ExtractOp>(loc, op.getSource(), srcIdx);
      result = rewriter.create<vector::InsertOp>(loc, e, result, resIdx);
    }
    rewriter.replaceOp(op, result);
    return success();
  }

private:
  // Increments the multi-dimensional index `idx` within the shape of `tp`,
  // carrying from dimension `r` towards more major dimensions.
  static void incIdx(SmallVector<int64_t> &idx, VectorType tp, int64_t r) {
    assert(0 <= r && r < tp.getRank());
    if (++idx[r] == tp.getDimSize(r)) {
      idx[r] = 0;
      incIdx(idx, tp, r - 1);
    }
  }
};
} // namespace

void mlir::vector::populateVectorShapeCastLoweringPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<ShapeCastOp2DDownCastRewritePattern,
               ShapeCastOp2DUpCastRewritePattern, ShapeCastOpRewritePattern>(
      patterns.getContext(), benefit);
}
@@ -14,7 +14,7 @@
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
-#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
 #include "mlir/Interfaces/VectorInterfaces.h"

 using namespace mlir;
@@ -46,6 +46,11 @@ static Value extendVectorRank(OpBuilder &builder, Location loc, Value vec,
   return builder.create<vector::BroadcastOp>(loc, newVecType, vec);
 }

+//===----------------------------------------------------------------------===//
+// populateVectorTransferPermutationMapLoweringPatterns
+//===----------------------------------------------------------------------===//
+
 namespace {
 /// Lower transfer_read op with permutation into a transfer_read with a
 /// permutation map composed of leading zeros followed by a minor identity +
 /// vector.transpose op.
@@ -332,6 +337,8 @@ struct TransferOpReduceRank : public OpRewritePattern<vector::TransferReadOp> {
   }
 };

 } // namespace

 void mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
   patterns
@@ -339,3 +346,239 @@ void mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(
            TransferOpReduceRank, TransferWriteNonPermutationLowering>(
       patterns.getContext(), benefit);
 }

//===----------------------------------------------------------------------===//
// populateVectorTransferLoweringPatterns
//===----------------------------------------------------------------------===//

namespace {
/// Progressive lowering of transfer_read. This pattern supports lowering of
/// `vector.transfer_read` to a combination of `vector.load` and
/// `vector.broadcast` if all of the following hold:
/// - Stride of most minor memref dimension must be 1.
/// - Out-of-bounds masking is not required.
/// - If the memref's element type is a vector type then it coincides with the
///   result type.
/// - The permutation map doesn't perform permutation (broadcasting is
///   allowed).
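///
/// For example (illustrative):
///   %v = vector.transfer_read %m[%i], %pad {in_bounds = [true]}
///       : memref<?xf32>, vector<8xf32>
/// lowers to:
///   %v = vector.load %m[%i] : memref<?xf32>, vector<8xf32>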
struct TransferReadToVectorLoadLowering
    : public OpRewritePattern<vector::TransferReadOp> {
  TransferReadToVectorLoadLowering(MLIRContext *context,
                                   std::optional<unsigned> maxRank,
                                   PatternBenefit benefit = 1)
      : OpRewritePattern<vector::TransferReadOp>(context, benefit),
        maxTransferRank(maxRank) {}

  LogicalResult matchAndRewrite(vector::TransferReadOp read,
                                PatternRewriter &rewriter) const override {
    if (maxTransferRank && read.getVectorType().getRank() > *maxTransferRank)
      return failure();

    SmallVector<unsigned> broadcastedDims;
    // Permutations are handled by VectorToSCF or
    // populateVectorTransferPermutationMapLoweringPatterns.
    // We let the 0-d corner case pass-through as it is supported.
    if (!read.getPermutationMap().isMinorIdentityWithBroadcasting(
            &broadcastedDims))
      return failure();

    auto memRefType = read.getShapedType().dyn_cast<MemRefType>();
    if (!memRefType)
      return failure();

    // Non-unit strides are handled by VectorToSCF.
    if (!vector::isLastMemrefDimUnitStride(memRefType))
      return failure();

    // If there is broadcasting involved then we first load the unbroadcasted
    // vector, and then broadcast it with `vector.broadcast`.
    ArrayRef<int64_t> vectorShape = read.getVectorType().getShape();
    SmallVector<int64_t> unbroadcastedVectorShape(vectorShape.begin(),
                                                  vectorShape.end());
    for (unsigned i : broadcastedDims)
      unbroadcastedVectorShape[i] = 1;
    VectorType unbroadcastedVectorType = VectorType::get(
        unbroadcastedVectorShape, read.getVectorType().getElementType());

    // `vector.load` supports vector types as memref's elements only when the
    // resulting vector type is the same as the element type.
    auto memrefElTy = memRefType.getElementType();
    if (memrefElTy.isa<VectorType>() && memrefElTy != unbroadcastedVectorType)
      return failure();

    // Otherwise, element types of the memref and the vector must match.
    if (!memrefElTy.isa<VectorType>() &&
        memrefElTy != read.getVectorType().getElementType())
      return failure();

    // Out-of-bounds dims are handled by MaterializeTransferMask.
    if (read.hasOutOfBoundsDim())
      return failure();

    // Create the vector load op.
    Operation *loadOp;
    if (read.getMask()) {
      Value fill = rewriter.create<vector::SplatOp>(
          read.getLoc(), unbroadcastedVectorType, read.getPadding());
      loadOp = rewriter.create<vector::MaskedLoadOp>(
          read.getLoc(), unbroadcastedVectorType, read.getSource(),
          read.getIndices(), read.getMask(), fill);
    } else {
      loadOp = rewriter.create<vector::LoadOp>(
          read.getLoc(), unbroadcastedVectorType, read.getSource(),
          read.getIndices());
    }

    // Insert a broadcasting op if required.
    if (!broadcastedDims.empty()) {
      rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
          read, read.getVectorType(), loadOp->getResult(0));
    } else {
      rewriter.replaceOp(read, loadOp->getResult(0));
    }

    return success();
  }

  std::optional<unsigned> maxTransferRank;
};

/// Replace a 0-d vector.load with a memref.load + vector.broadcast.
// TODO: we shouldn't cross the vector/scalar domains just for this
// but atm we lack the infra to avoid it. Possible solutions include:
// - go directly to LLVM + bitcast
// - introduce a bitcast op and likely a new pointer dialect
// - let memref.load/store additionally support the 0-d vector case
// There are still deeper data layout issues lingering even in this
// trivial case (for architectures for which this matters).
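///
/// For example (illustrative):
///   %v = vector.load %m[] : memref<f32>, vector<f32>
/// becomes:
///   %s = memref.load %m[] : memref<f32>
///   %v = vector.broadcast %s : f32 to vector<f32>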
struct VectorLoadToMemrefLoadLowering
    : public OpRewritePattern<vector::LoadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::LoadOp loadOp,
                                PatternRewriter &rewriter) const override {
    auto vecType = loadOp.getVectorType();
    if (vecType.getNumElements() != 1)
      return failure();
    auto memrefLoad = rewriter.create<memref::LoadOp>(
        loadOp.getLoc(), loadOp.getBase(), loadOp.getIndices());
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(loadOp, vecType,
                                                     memrefLoad);
    return success();
  }
};

/// Replace a 0-d vector.store with a vector.extractelement + memref.store.
struct VectorStoreToMemrefStoreLowering
    : public OpRewritePattern<vector::StoreOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::StoreOp storeOp,
                                PatternRewriter &rewriter) const override {
    auto vecType = storeOp.getVectorType();
    if (vecType.getNumElements() != 1)
      return failure();
    Value extracted;
    if (vecType.getRank() == 0) {
      // TODO: Unify once ExtractOp supports 0-d vectors.
      extracted = rewriter.create<vector::ExtractElementOp>(
          storeOp.getLoc(), storeOp.getValueToStore());
    } else {
      SmallVector<int64_t> indices(vecType.getRank(), 0);
      extracted = rewriter.create<vector::ExtractOp>(
          storeOp.getLoc(), storeOp.getValueToStore(), indices);
    }

    rewriter.replaceOpWithNewOp<memref::StoreOp>(
        storeOp, extracted, storeOp.getBase(), storeOp.getIndices());
    return success();
  }
};

/// Progressive lowering of transfer_write. This pattern supports lowering of
/// `vector.transfer_write` to `vector.store` if all of the following hold:
/// - Stride of most minor memref dimension must be 1.
/// - Out-of-bounds masking is not required.
/// - If the memref's element type is a vector type then it coincides with the
///   type of the written value.
/// - The permutation map is the minor identity map (neither permutation nor
///   broadcasting is allowed).
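///
/// For example (illustrative):
///   vector.transfer_write %v, %m[%i] {in_bounds = [true]}
///       : vector<8xf32>, memref<?xf32>
/// lowers to:
///   vector.store %v, %m[%i] : memref<?xf32>, vector<8xf32>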
struct TransferWriteToVectorStoreLowering
    : public OpRewritePattern<vector::TransferWriteOp> {
  TransferWriteToVectorStoreLowering(MLIRContext *context,
                                     std::optional<unsigned> maxRank,
                                     PatternBenefit benefit = 1)
      : OpRewritePattern<vector::TransferWriteOp>(context, benefit),
        maxTransferRank(maxRank) {}

  LogicalResult matchAndRewrite(vector::TransferWriteOp write,
                                PatternRewriter &rewriter) const override {
    if (maxTransferRank && write.getVectorType().getRank() > *maxTransferRank)
      return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
        diag << "rank exceeds maxTransferRank: " << write;
      });

    // Permutations are handled by VectorToSCF or
    // populateVectorTransferPermutationMapLoweringPatterns.
    if ( // pass-through for the 0-d corner case.
        !write.getPermutationMap().isMinorIdentity())
      return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
        diag << "permutation map is not minor identity: " << write;
      });

    auto memRefType = write.getShapedType().dyn_cast<MemRefType>();
    if (!memRefType)
      return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
        diag << "not a memref type: " << write;
      });

    // Non-unit strides are handled by VectorToSCF.
    if (!vector::isLastMemrefDimUnitStride(memRefType))
      return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
        diag << "most minor stride is not 1: " << write;
      });

    // `vector.store` supports vector types as memref's elements only when the
    // type of the vector value being written is the same as the element type.
    auto memrefElTy = memRefType.getElementType();
    if (memrefElTy.isa<VectorType>() && memrefElTy != write.getVectorType())
      return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
        diag << "elemental type mismatch: " << write;
      });

    // Otherwise, element types of the memref and the vector must match.
    if (!memrefElTy.isa<VectorType>() &&
        memrefElTy != write.getVectorType().getElementType())
      return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
        diag << "elemental type mismatch: " << write;
      });

    // Out-of-bounds dims are handled by MaterializeTransferMask.
    if (write.hasOutOfBoundsDim())
      return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
        diag << "out of bounds dim: " << write;
      });
    if (write.getMask()) {
      rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
          write, write.getSource(), write.getIndices(), write.getMask(),
          write.getVector());
    } else {
      rewriter.replaceOpWithNewOp<vector::StoreOp>(
          write, write.getVector(), write.getSource(), write.getIndices());
    }
    return success();
  }

  std::optional<unsigned> maxTransferRank;
};
} // namespace

void mlir::vector::populateVectorTransferLoweringPatterns(
    RewritePatternSet &patterns, std::optional<unsigned> maxTransferRank,
    PatternBenefit benefit) {
  patterns.add<TransferReadToVectorLoadLowering,
               TransferWriteToVectorStoreLowering>(patterns.getContext(),
                                                   maxTransferRank, benefit);
  patterns
      .add<VectorLoadToMemrefLoadLowering, VectorStoreToMemrefStoreLowering>(
          patterns.getContext(), benefit);
}
210 mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp Normal file
@@ -0,0 +1,210 @@
//===- LowerVectorTranspose.cpp - Lower 'vector.transpose' operation ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent rewrites and utilities to lower the
// 'vector.transpose' operation.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/VectorInterfaces.h"
#include "mlir/Support/LogicalResult.h"

#define DEBUG_TYPE "vector-transpose-lowering"

using namespace mlir;
using namespace mlir::vector;

/// Given a 'transpose' pattern, prune the rightmost dimensions that are not
/// transposed.
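/// For example (illustrative), the pattern [2, 0, 1, 3] keeps [2, 0, 1] and
/// prunes the trailing dimension 3, which is already in its identity position.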
static void pruneNonTransposedDims(ArrayRef<int64_t> transpose,
                                   SmallVectorImpl<int64_t> &result) {
  size_t numTransposedDims = transpose.size();
  for (size_t transpDim : llvm::reverse(transpose)) {
    if (transpDim != numTransposedDims - 1)
      break;
    numTransposedDims--;
  }

  result.append(transpose.begin(), transpose.begin() + numTransposedDims);
}

namespace {
/// Progressive lowering of TransposeOp.
/// One:
///   %x = vector.transpose %y, [1, 0]
/// is replaced by:
///   %z = arith.constant dense<0.000000e+00>
///   %0 = vector.extract %y[0, 0]
///   %1 = vector.insert %0, %z [0, 0]
///   ..
///   %x = vector.insert .., .. [.., ..]
class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  TransposeOpLowering(vector::VectorTransformsOptions vectorTransformOptions,
                      MLIRContext *context, PatternBenefit benefit = 1)
      : OpRewritePattern<vector::TransposeOp>(context, benefit),
        vectorTransformOptions(vectorTransformOptions) {}

  LogicalResult matchAndRewrite(vector::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();

    Value input = op.getVector();
    VectorType inputType = op.getSourceVectorType();
    VectorType resType = op.getResultVectorType();

    // Set up convenience transposition table.
    SmallVector<int64_t> transp;
    for (auto attr : op.getTransp())
      transp.push_back(attr.cast<IntegerAttr>().getInt());

    if (vectorTransformOptions.vectorTransposeLowering ==
            vector::VectorTransposeLowering::Shuffle &&
        resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0)
      return rewriter.notifyMatchFailure(
          op, "options specify lowering to shuffle");

    // Handle a true 2-D matrix transpose differently when requested.
    if (vectorTransformOptions.vectorTransposeLowering ==
            vector::VectorTransposeLowering::Flat &&
        resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0) {
      Type flattenedType =
          VectorType::get(resType.getNumElements(), resType.getElementType());
      auto matrix =
          rewriter.create<vector::ShapeCastOp>(loc, flattenedType, input);
      auto rows = rewriter.getI32IntegerAttr(resType.getShape()[0]);
      auto columns = rewriter.getI32IntegerAttr(resType.getShape()[1]);
      Value trans = rewriter.create<vector::FlatTransposeOp>(
          loc, flattenedType, matrix, rows, columns);
      rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, trans);
      return success();
    }

    // Generate unrolled extract/insert ops. We do not unroll the rightmost
    // (i.e., highest-order) dimensions that are not transposed and leave them
    // in vector form to improve performance. Therefore, we prune those
    // dimensions from the shape/transpose data structures used to generate the
    // extract/insert ops.
    SmallVector<int64_t> prunedTransp;
    pruneNonTransposedDims(transp, prunedTransp);
    size_t numPrunedDims = transp.size() - prunedTransp.size();
    auto prunedInShape = inputType.getShape().drop_back(numPrunedDims);
    auto prunedInStrides = computeStrides(prunedInShape);

    // Generate the extract/insert operations for every scalar/vector element
    // of the leftmost transposed dimensions. We traverse every transpose
    // element using a linearized index that we delinearize to generate the
    // appropriate indices for the extract/insert operations.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, resType, rewriter.getZeroAttr(resType));
    int64_t numTransposedElements = ShapedType::getNumElements(prunedInShape);

    for (int64_t linearIdx = 0; linearIdx < numTransposedElements;
         ++linearIdx) {
      auto extractIdxs = delinearize(linearIdx, prunedInStrides);
      SmallVector<int64_t> insertIdxs(extractIdxs);
      applyPermutationToVector(insertIdxs, prunedTransp);
      Value extractOp =
          rewriter.create<vector::ExtractOp>(loc, input, extractIdxs);
      result =
          rewriter.create<vector::InsertOp>(loc, extractOp, result, insertIdxs);
    }

    rewriter.replaceOp(op, result);
    return success();
  }

private:
  /// Options to control the vector patterns.
  vector::VectorTransformsOptions vectorTransformOptions;
};

/// Rewrite a 2-D vector.transpose as a sequence of:
///   vector.shape_cast 2D -> 1D
///   vector.shuffle
///   vector.shape_cast 1D -> 2D
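///
/// For example (illustrative), transposing a vector<2x3xf32> shuffles the
/// flattened vector<6xf32> with mask [0, 3, 1, 4, 2, 5] before casting the
/// result back to vector<3x2xf32>.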
class TransposeOp2DToShuffleLowering
    : public OpRewritePattern<vector::TransposeOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  TransposeOp2DToShuffleLowering(
      vector::VectorTransformsOptions vectorTransformOptions,
      MLIRContext *context, PatternBenefit benefit = 1)
      : OpRewritePattern<vector::TransposeOp>(context, benefit),
        vectorTransformOptions(vectorTransformOptions) {}

  LogicalResult matchAndRewrite(vector::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();

    VectorType srcType = op.getSourceVectorType();
    if (srcType.getRank() != 2)
      return rewriter.notifyMatchFailure(op, "not a 2-D transpose");

    SmallVector<int64_t> transp;
    for (auto attr : op.getTransp())
      transp.push_back(attr.cast<IntegerAttr>().getInt());
    if (transp[0] != 1 || transp[1] != 0)
      return rewriter.notifyMatchFailure(op, "not a 2-D transpose permutation");

    if (vectorTransformOptions.vectorTransposeLowering !=
        VectorTransposeLowering::Shuffle)
      return rewriter.notifyMatchFailure(op, "options do not ask for Shuffle");

    int64_t m = srcType.getShape().front(), n = srcType.getShape().back();
    Value casted = rewriter.create<vector::ShapeCastOp>(
        loc, VectorType::get({m * n}, srcType.getElementType()),
        op.getVector());
    SmallVector<int64_t> mask;
    mask.reserve(m * n);
    for (int64_t j = 0; j < n; ++j)
      for (int64_t i = 0; i < m; ++i)
        mask.push_back(i * n + j);

    Value shuffled =
        rewriter.create<vector::ShuffleOp>(loc, casted, casted, mask);
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
        op, op.getResultVectorType(), shuffled);

    return success();
  }

private:
  /// Options to control the vector patterns.
  vector::VectorTransformsOptions vectorTransformOptions;
};
} // namespace

void mlir::vector::populateVectorTransposeLoweringPatterns(
    RewritePatternSet &patterns, VectorTransformsOptions options,
    PatternBenefit benefit) {
  patterns.add<TransposeOpLowering, TransposeOp2DToShuffleLowering>(
      options, patterns.getContext(), benefit);
}
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
 #include "mlir/IR/BuiltinOps.h"
@@ -11,8 +11,8 @@
 //
 //===----------------------------------------------------------------------===//

-#include <type_traits>
+#include <optional>
 #include <type_traits>

 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
@@ -92,11 +92,11 @@ static Value createInBoundsCond(RewriterBase &b,
 }

 /// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds
-/// masking) fastpath and a slowpath.
+/// masking) fast path and a slow path.
 /// If `ifOp` is not null and the result is `success`, the `ifOp` points to the
 /// newly created conditional upon function return.
-/// To accomodate for the fact that the original vector.transfer indexing may be
-/// arbitrary and the slow path indexes @[0...0] in the temporary buffer, the
+/// To accommodate for the fact that the original vector.transfer indexing may
+/// be arbitrary and the slow path indexes @[0...0] in the temporary buffer, the
 /// scf.if op returns a view and values of type index.
 /// At this time, only the vector.transfer_read case is implemented.
 ///
@@ -107,11 +107,11 @@ static Value createInBoundsCond(RewriterBase &b,
 /// is transformed into:
 /// ```
 ///    %1:3 = scf.if (%inBounds) {
-///      // fastpath, direct cast
+///      // fast path, direct cast
 ///      memref.cast %A: memref<A...> to compatibleMemRefType
 ///      scf.yield %view : compatibleMemRefType, index, index
 ///    } else {
-///      // slowpath, not in-bounds vector.transfer or linalg.copy.
+///      // slow path, not in-bounds vector.transfer or linalg.copy.
 ///      memref.cast %alloc: memref<B...> to compatibleMemRefType
 ///      scf.yield %4 : compatibleMemRefType, index, index
 // }
@@ -172,12 +172,10 @@ static MemRefType getCastCompatibleMemRefType(MemRefType aT, MemRefType bT) {
   for (int64_t idx = 0, e = aT.getRank(); idx < e; ++idx) {
     resShape[idx] =
         (aShape[idx] == bShape[idx]) ? aShape[idx] : ShapedType::kDynamic;
-    resStrides[idx] = (aStrides[idx] == bStrides[idx])
-                          ? aStrides[idx]
-                          : ShapedType::kDynamic;
+    resStrides[idx] =
+        (aStrides[idx] == bStrides[idx]) ? aStrides[idx] : ShapedType::kDynamic;
   }
-  resOffset =
-      (aOffset == bOffset) ? aOffset : ShapedType::kDynamic;
+  resOffset = (aOffset == bOffset) ? aOffset : ShapedType::kDynamic;
   return MemRefType::get(
       resShape, aT.getElementType(),
       StridedLayoutAttr::get(aT.getContext(), resOffset, resStrides));
@@ -634,7 +632,34 @@ LogicalResult mlir::vector::splitFullAndPartialTransfer(
   return success();
 }

-LogicalResult mlir::vector::VectorTransferFullPartialRewriter::matchAndRewrite(
+namespace {
+/// Apply `splitFullAndPartialTransfer` selectively via a pattern. This pattern
+/// may take an extra filter to perform selection at a finer granularity.
+struct VectorTransferFullPartialRewriter : public RewritePattern {
+  using FilterConstraintType =
+      std::function<LogicalResult(VectorTransferOpInterface op)>;
+
+  explicit VectorTransferFullPartialRewriter(
+      MLIRContext *context,
+      VectorTransformsOptions options = VectorTransformsOptions(),
+      FilterConstraintType filter =
+          [](VectorTransferOpInterface op) { return success(); },
+      PatternBenefit benefit = 1)
+      : RewritePattern(MatchAnyOpTypeTag(), benefit, context), options(options),
+        filter(std::move(filter)) {}
+
+  /// Performs the rewrite.
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override;
+
+private:
+  VectorTransformsOptions options;
+  FilterConstraintType filter;
+};
+
+} // namespace
+
+LogicalResult VectorTransferFullPartialRewriter::matchAndRewrite(
     Operation *op, PatternRewriter &rewriter) const {
   auto xferOp = dyn_cast<VectorTransferOpInterface>(op);
   if (!xferOp || failed(splitFullAndPartialTransferPrecondition(xferOp)) ||
@@ -642,3 +667,9 @@ LogicalResult mlir::vector::VectorTransferFullPartialRewriter::matchAndRewrite(
     return failure();
   return splitFullAndPartialTransfer(rewriter, xferOp, options);
 }
+
+void mlir::vector::populateVectorTransferFullPartialPatterns(
+    RewritePatternSet &patterns, const VectorTransformsOptions &options) {
+  patterns.add<VectorTransferFullPartialRewriter>(patterns.getContext(),
+                                                  options);
+}
File diff suppressed because it is too large
@@ -22,6 +22,7 @@
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
 #include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
@@ -148,8 +149,9 @@ struct TestVectorContractionLowering
     if (lowerToOuterProduct) {
       VectorContractLowering lowering = VectorContractLowering::OuterProduct;
       VectorTransformsOptions options{lowering};
-      patterns.add<ContractionOpToOuterProductOpLowering>(options,
-                                                          &getContext());
+      populateVectorContractLoweringPatterns(
+          patterns, options, /*benefit=*/1,
+          /*disableOuterProductLowering=*/true);
       (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
       return;
     }
@@ -469,7 +471,7 @@ struct TestVectorTransferFullPartialSplitPatterns
       options.setVectorTransferSplit(VectorTransferSplit::LinalgCopy);
     else
       options.setVectorTransferSplit(VectorTransferSplit::VectorTransfer);
-    patterns.add<VectorTransferFullPartialRewriter>(ctx, options);
+    populateVectorTransferFullPartialPatterns(patterns, options);
     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
   }
 };
@@ -8539,6 +8539,7 @@ cc_library(
         ":TransformDialect",
        ":TransformDialectUtils",
        ":TransformUtils",
+        ":VectorTransforms",
        "//llvm:Support",
    ],
)