mirror of
https://github.com/intel/llvm.git
synced 2026-01-25 01:07:04 +08:00
[mlir][vector] Add lowering of Transfer_read with broadcast and permutation map
Convert transfer_read ops with permutation maps into simpler transfer_read with minority map + vector.braodcast and vector.transpose. And transfer_read with leading dimensions broacast into transfer_read of lower rank. Differential Revision: https://reviews.llvm.org/D99019
This commit is contained in:
@@ -113,6 +113,22 @@ public:
|
||||
bool isMinorIdentityWithBroadcasting(
|
||||
SmallVectorImpl<unsigned> *broadcastedDims = nullptr) const;
|
||||
|
||||
/// Return true if this affine map can be converted to a minor identity with
|
||||
/// broadcast by doing a permute. Return a permutation (there may be
|
||||
/// several) to apply to get to a minor identity with broadcasts.
|
||||
/// Ex:
|
||||
/// * (d0, d1, d2) -> (0, d1) maps to minor identity (d1, 0 = d2) with
|
||||
/// perm = [1, 0] and broadcast d2
|
||||
/// * (d0, d1, d2) -> (d0, 0) cannot be mapped to a minor identity by
|
||||
/// permutation + broadcast
|
||||
/// * (d0, d1, d2, d3) -> (0, d1, d3) maps to minor identity (d1, 0 = d2, d3)
|
||||
/// with perm = [1, 0, 2] and broadcast d2
|
||||
/// * (d0, d1) -> (d1, 0, 0, d0) maps to minor identity (d0, d1) with extra
|
||||
/// leading broadcat dimensions. The map returned would be (0, 0, d0, d1)
|
||||
/// with perm = [3, 0, 1, 2]
|
||||
bool isPermutationOfMinorIdentityWithBroadcasting(
|
||||
SmallVectorImpl<unsigned> &permutedDims) const;
|
||||
|
||||
/// Returns true if this affine map is an empty map, i.e., () -> ().
|
||||
bool isEmpty() const;
|
||||
|
||||
|
||||
@@ -2842,6 +2842,113 @@ struct TransferWriteToVectorStoreLowering
|
||||
}
|
||||
};
|
||||
|
||||
/// Lower transfer_read op with permutation into a transfer_read with a
|
||||
/// permutation map composed of leading zeros followed by a minor identiy +
|
||||
/// vector.transpose op.
|
||||
/// Ex:
|
||||
/// vector.transfer_read ...
|
||||
/// permutation_map: (d0, d1, d2) -> (0, d1)
|
||||
/// into:
|
||||
/// %v = vector.transfer_read ...
|
||||
/// permutation_map: (d0, d1, d2) -> (d1, 0)
|
||||
/// vector.transpose %v, [1, 0]
|
||||
///
|
||||
/// vector.transfer_read ...
|
||||
/// permutation_map: (d0, d1, d2, d3) -> (0, 0, 0, d1, d3)
|
||||
/// into:
|
||||
/// %v = vector.transfer_read ...
|
||||
/// permutation_map: (d0, d1, d2, d3) -> (0, 0, d1, 0, d3)
|
||||
/// vector.transpose %v, [0, 1, 3, 2, 4]
|
||||
/// Note that an alternative is to transform it to linalg.transpose +
|
||||
/// vector.transfer_read to do the transpose in memory instead.
|
||||
struct TransferReadPermutationLowering
|
||||
: public OpRewritePattern<vector::TransferReadOp> {
|
||||
using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(vector::TransferReadOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
SmallVector<unsigned> permutation;
|
||||
AffineMap map = op.permutation_map();
|
||||
if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation))
|
||||
return failure();
|
||||
|
||||
AffineMap permutationMap =
|
||||
map.getPermutationMap(permutation, op.getContext());
|
||||
if (permutationMap.isIdentity())
|
||||
return failure();
|
||||
// Caluclate the map of the new read by applying the inverse permutation.
|
||||
permutationMap = inversePermutation(permutationMap);
|
||||
AffineMap newMap = permutationMap.compose(map);
|
||||
// Apply the reverse transpose to deduce the type of the transfer_read.
|
||||
ArrayRef<int64_t> originalShape = op.getVectorType().getShape();
|
||||
SmallVector<int64_t> newVectorShape(originalShape.size());
|
||||
for (auto pos : llvm::enumerate(permutation)) {
|
||||
newVectorShape[pos.value()] = originalShape[pos.index()];
|
||||
}
|
||||
VectorType newReadType =
|
||||
VectorType::get(newVectorShape, op.getVectorType().getElementType());
|
||||
Value newRead = rewriter.create<vector::TransferReadOp>(
|
||||
op.getLoc(), newReadType, op.source(), op.indices(), newMap,
|
||||
op.padding(), op.masked() ? *op.masked() : ArrayAttr());
|
||||
SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end());
|
||||
rewriter.replaceOpWithNewOp<vector::TransposeOp>(op, newRead,
|
||||
transposePerm);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/// Lower transfer_read op with broadcast in the leading dimensions into
|
||||
/// transfer_read of lower rank + vector.broadcast.
|
||||
/// Ex: vector.transfer_read ...
|
||||
/// permutation_map: (d0, d1, d2, d3) -> (0, d1, 0, d3)
|
||||
/// into:
|
||||
/// %v = vector.transfer_read ...
|
||||
/// permutation_map: (d0, d1, d2, d3) -> (d1, 0, d3)
|
||||
/// vector.broadcast %v
|
||||
struct TransferOpReduceRank : public OpRewritePattern<vector::TransferReadOp> {
|
||||
using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(vector::TransferReadOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
AffineMap map = op.permutation_map();
|
||||
unsigned numLeadingBroadcast = 0;
|
||||
for (auto expr : map.getResults()) {
|
||||
auto dimExpr = expr.dyn_cast<AffineConstantExpr>();
|
||||
if (!dimExpr || dimExpr.getValue() != 0)
|
||||
break;
|
||||
numLeadingBroadcast++;
|
||||
}
|
||||
// If there are no leading zeros in the map there is nothing to do.
|
||||
if (numLeadingBroadcast == 0)
|
||||
return failure();
|
||||
VectorType originalVecType = op.getVectorType();
|
||||
unsigned reducedShapeRank = originalVecType.getRank() - numLeadingBroadcast;
|
||||
// Calculate new map, vector type and masks without the leading zeros.
|
||||
AffineMap newMap = AffineMap::get(
|
||||
map.getNumDims(), 0, map.getResults().take_back(reducedShapeRank),
|
||||
op.getContext());
|
||||
// Only remove the leading zeros if the rest of the map is a minor identity
|
||||
// with broadasting. Otherwise we first want to permute the map.
|
||||
if (!newMap.isMinorIdentityWithBroadcasting())
|
||||
return failure();
|
||||
SmallVector<int64_t> newShape = llvm::to_vector<4>(
|
||||
originalVecType.getShape().take_back(reducedShapeRank));
|
||||
VectorType newReadType =
|
||||
VectorType::get(newShape, originalVecType.getElementType());
|
||||
ArrayAttr newMask =
|
||||
op.masked()
|
||||
? rewriter.getArrayAttr(
|
||||
op.maskedAttr().getValue().take_back(reducedShapeRank))
|
||||
: ArrayAttr();
|
||||
Value newRead = rewriter.create<vector::TransferReadOp>(
|
||||
op.getLoc(), newReadType, op.source(), op.indices(), newMap,
|
||||
op.padding(), newMask);
|
||||
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, originalVecType,
|
||||
newRead);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
// Trims leading one dimensions from `oldType` and returns the result type.
|
||||
// Returns `vector<1xT>` if `oldType` only has one element.
|
||||
static VectorType trimLeadingOneDims(VectorType oldType) {
|
||||
@@ -3317,6 +3424,8 @@ void mlir::vector::populateVectorContractLoweringPatterns(
|
||||
|
||||
void mlir::vector::populateVectorTransferLoweringPatterns(
|
||||
RewritePatternSet &patterns) {
|
||||
patterns.add<TransferReadToVectorLoadLowering,
|
||||
TransferWriteToVectorStoreLowering>(patterns.getContext());
|
||||
patterns
|
||||
.add<TransferReadToVectorLoadLowering, TransferWriteToVectorStoreLowering,
|
||||
TransferReadPermutationLowering, TransferOpReduceRank>(
|
||||
patterns.getContext());
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
#include "mlir/Support/MathExtras.h"
|
||||
#include "llvm/ADT/SmallBitVector.h"
|
||||
#include "llvm/ADT/SmallSet.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
@@ -140,6 +141,66 @@ bool AffineMap::isMinorIdentityWithBroadcasting(
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Return true if this affine map can be converted to a minor identity with
|
||||
/// broadcast by doing a permute. Return a permutation (there may be
|
||||
/// several) to apply to get to a minor identity with broadcasts.
|
||||
/// Ex:
|
||||
/// * (d0, d1, d2) -> (0, d1) maps to minor identity (d1, 0 = d2) with
|
||||
/// perm = [1, 0] and broadcast d2
|
||||
/// * (d0, d1, d2) -> (d0, 0) cannot be mapped to a minor identity by
|
||||
/// permutation + broadcast
|
||||
/// * (d0, d1, d2, d3) -> (0, d1, d3) maps to minor identity (d1, 0 = d2, d3)
|
||||
/// with perm = [1, 0, 2] and broadcast d2
|
||||
/// * (d0, d1) -> (d1, 0, 0, d0) maps to minor identity (d0, d1) with extra
|
||||
/// leading broadcat dimensions. The map returned would be (0, 0, d0, d1) with
|
||||
/// perm = [3, 0, 1, 2]
|
||||
bool AffineMap::isPermutationOfMinorIdentityWithBroadcasting(
|
||||
SmallVectorImpl<unsigned> &permutedDims) const {
|
||||
unsigned projectionStart =
|
||||
getNumResults() < getNumInputs() ? getNumInputs() - getNumResults() : 0;
|
||||
permutedDims.clear();
|
||||
SmallVector<unsigned> broadcastDims;
|
||||
permutedDims.resize(getNumResults(), 0);
|
||||
// If there are more results than input dimensions we want the new map to
|
||||
// start with broadcast dimensions in order to be a minor identity with
|
||||
// broadcasting.
|
||||
unsigned leadingBroadcast =
|
||||
getNumResults() > getNumInputs() ? getNumResults() - getNumInputs() : 0;
|
||||
llvm::SmallBitVector dimFound(std::max(getNumInputs(), getNumResults()),
|
||||
false);
|
||||
for (auto idxAndExpr : llvm::enumerate(getResults())) {
|
||||
unsigned resIdx = idxAndExpr.index();
|
||||
AffineExpr expr = idxAndExpr.value();
|
||||
// Each result may be either a constant 0 (broadcast dimension) or a
|
||||
// dimension.
|
||||
if (auto constExpr = expr.dyn_cast<AffineConstantExpr>()) {
|
||||
if (constExpr.getValue() != 0)
|
||||
return false;
|
||||
broadcastDims.push_back(resIdx);
|
||||
} else if (auto dimExpr = expr.dyn_cast<AffineDimExpr>()) {
|
||||
if (dimExpr.getPosition() < projectionStart)
|
||||
return false;
|
||||
unsigned newPosition =
|
||||
dimExpr.getPosition() - projectionStart + leadingBroadcast;
|
||||
permutedDims[resIdx] = newPosition;
|
||||
dimFound[newPosition] = true;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
// Find a permuation for the broadcast dimension. Since they are broadcasted
|
||||
// any valid permutation is acceptable. We just permute the dim into a slot
|
||||
// without an existing dimension.
|
||||
unsigned pos = 0;
|
||||
for (auto dim : broadcastDims) {
|
||||
while (pos < dimFound.size() && dimFound[pos]) {
|
||||
pos++;
|
||||
}
|
||||
permutedDims[dim] = pos++;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Returns an AffineMap representing a permutation.
|
||||
AffineMap AffineMap::getPermutationMap(ArrayRef<unsigned> permutation,
|
||||
MLIRContext *context) {
|
||||
|
||||
@@ -206,3 +206,56 @@ func @transfer_broadcasting_complex(%mem : memref<10x20x30x8x8xf32>, %i : index)
|
||||
%res = vector.transfer_read %mem[%i, %i, %i, %i, %i], %cf0 {masked = [false, false, false, false], permutation_map = #broadcast} : memref<10x20x30x8x8xf32>, vector<3x2x4x5xf32>
|
||||
return %res : vector<3x2x4x5xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
#map0 = affine_map<(d0, d1, d2, d3) -> (d1, d0, 0, 0)>
|
||||
#map1 = affine_map<(d0, d1, d2, d3) -> (0, d1, 0, d0)>
|
||||
#map2 = affine_map<(d0, d1, d2, d3) -> (d3, d1, 0, 0)>
|
||||
#map3 = affine_map<(d0, d1) -> (d1, d0, 0, 0)>
|
||||
#map4 = affine_map<(d0, d1) -> (0, d1, 0, d0)>
|
||||
#map5 = affine_map<(d0, d1, d2, d3) -> (d2, d1, d3, d0)>
|
||||
|
||||
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, 0, 0)>
|
||||
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, 0, d3)>
|
||||
|
||||
// CHECK-LABEL: func @transfer_read_permutations
|
||||
func @transfer_read_permutations(%arg0 : memref<?x?xf32>, %arg1 : memref<?x?x?x?xf32>)
|
||||
-> (vector<7x14x8x16xf32>, vector<7x14x8x16xf32>, vector<7x14x8x16xf32>,
|
||||
vector<7x14x8x16xf32>, vector<7x14x8x16xf32>, vector<7x14x8x16xf32>) {
|
||||
// CHECK-DAG: %[[CF0:.*]] = constant 0.000000e+00 : f32
|
||||
// CHECK-DAG: %[[C0:.*]] = constant 0 : index
|
||||
%cst = constant 0.000000e+00 : f32
|
||||
%c0 = constant 0 : index
|
||||
|
||||
%0 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %cst {permutation_map = #map0} : memref<?x?x?x?xf32>, vector<7x14x8x16xf32>
|
||||
// CHECK: vector.transfer_read {{.*}} {permutation_map = #[[$MAP0]]} : memref<?x?x?x?xf32>, vector<14x7x8x16xf32>
|
||||
// CHECK: vector.transpose %{{.*}}, [1, 0, 2, 3] : vector<14x7x8x16xf32> to vector<7x14x8x16xf32>
|
||||
|
||||
%1 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %cst {permutation_map = #map1} : memref<?x?x?x?xf32>, vector<7x14x8x16xf32>
|
||||
// CHECK: vector.transfer_read {{.*}} {permutation_map = #[[$MAP0]]} : memref<?x?x?x?xf32>, vector<16x14x7x8xf32>
|
||||
// CHECK: vector.transpose %{{.*}}, [2, 1, 3, 0] : vector<16x14x7x8xf32> to vector<7x14x8x16xf32>
|
||||
|
||||
%2 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %cst {masked = [false, false, true, false], permutation_map = #map2} : memref<?x?x?x?xf32>, vector<7x14x8x16xf32>
|
||||
// CHECK: vector.transfer_read {{.*}} {masked = [false, true, false], permutation_map = #[[$MAP1]]} : memref<?x?x?x?xf32>, vector<14x16x7xf32>
|
||||
// CHECK: vector.broadcast %{{.*}} : vector<14x16x7xf32> to vector<8x14x16x7xf32>
|
||||
// CHECK: vector.transpose %{{.*}}, [3, 1, 0, 2] : vector<8x14x16x7xf32> to vector<7x14x8x16xf32>
|
||||
|
||||
%3 = vector.transfer_read %arg0[%c0, %c0], %cst {permutation_map = #map3} : memref<?x?xf32>, vector<7x14x8x16xf32>
|
||||
// CHECK: vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %[[CF0]] : memref<?x?xf32>, vector<14x7xf32>
|
||||
// CHECK: vector.broadcast %{{.*}} : vector<14x7xf32> to vector<8x16x14x7xf32>
|
||||
// CHECK: vector.transpose %{{.*}}, [3, 2, 0, 1] : vector<8x16x14x7xf32> to vector<7x14x8x16xf32>
|
||||
|
||||
%4 = vector.transfer_read %arg0[%c0, %c0], %cst {permutation_map = #map4} : memref<?x?xf32>, vector<7x14x8x16xf32>
|
||||
// CHECK: vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %[[CF0]] : memref<?x?xf32>, vector<16x14xf32>
|
||||
// CHECK: vector.broadcast %{{.*}} : vector<16x14xf32> to vector<7x8x16x14xf32>
|
||||
// CHECK: vector.transpose %{{.*}}, [0, 3, 1, 2] : vector<7x8x16x14xf32> to vector<7x14x8x16xf32>
|
||||
|
||||
%5 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %cst {permutation_map = #map5} : memref<?x?x?x?xf32>, vector<7x14x8x16xf32>
|
||||
// CHECK: vector.transfer_read %{{.*}}[%[[C0]], %[[C0]], %[[C0]], %[[C0]]], %[[CF0]] : memref<?x?x?x?xf32>, vector<16x14x7x8xf32>
|
||||
// CHECK: vector.transpose %{{.*}}, [2, 1, 3, 0] : vector<16x14x7x8xf32> to vector<7x14x8x16xf32>
|
||||
|
||||
return %0, %1, %2, %3, %4, %5 : vector<7x14x8x16xf32>, vector<7x14x8x16xf32>,
|
||||
vector<7x14x8x16xf32>, vector<7x14x8x16xf32>, vector<7x14x8x16xf32>,
|
||||
vector<7x14x8x16xf32>
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user