[mlir][VectorOps] Support vector transfer_read/write unrolling for memrefs with vector element type.

Summary:
When unrolling vector transfer_read/transfer_write ops on memrefs, the indices used to index the memref argument must be updated to reflect the unrolled operation. However, when the memref has a vector element type, care must be taken to update only the relevant memref indices.

For example, a vector transfer_read with source type memref<6x2x1xvector<2x4xf32>> and result type vector<2x1x2x4xf32> should only update memref indices 1 and 2 during unrolling.
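A minimal sketch of the intended rewrite for that example (hand-written illustration with hypothetical SSA names, not compiler output; the real behavior is exercised by the test at the end of this diff):

  // Before unrolling: the trailing two dims of the result vector are covered
  // by the vector element type, so only memref indices 1 and 2 participate.
  %v = vector.transfer_read %m[%i0, %i1, %i2], %pad
      {permutation_map = affine_map<(d0, d1, d2) -> (d1, d2)>}
      : memref<6x2x1xvector<2x4xf32>>, vector<2x1x2x4xf32>

  // After unrolling into two vector<1x1x2x4xf32> slices: index 0 (%i0) is
  // passed through unchanged; index 1 is offset by 1 for the second slice
  // (index 2 keeps a zero offset for both slices).
  %i1p = affine.apply affine_map<(d0) -> (d0 + 1)>(%i1)
  %s0 = vector.transfer_read %m[%i0, %i1, %i2], %pad
      {permutation_map = affine_map<(d0, d1, d2) -> (d1, d2)>}
      : memref<6x2x1xvector<2x4xf32>>, vector<1x1x2x4xf32>
  %s1 = vector.transfer_read %m[%i0, %i1p, %i2], %pad
      {permutation_map = affine_map<(d0, d1, d2) -> (d1, d2)>}
      : memref<6x2x1xvector<2x4xf32>>, vector<1x1x2x4xf32>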

Reviewers: nicolasvasilache, aartbik

Reviewed By: nicolasvasilache, aartbik

Subscribers: lebedev.ri, Joonsoo, merge_guards_bot, mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, arpith-jacob, mgester, lucyrfox, liufengdb, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D72965
commit f9efce1dd5
parent 643dee903c
Author: Andy Davis
Date:   2020-02-05 16:06:40 -08:00

2 changed files with 101 additions and 17 deletions


@@ -460,12 +460,13 @@ SmallVector<Value, 1> mlir::vector::unrollSingleResultOpMatchingType(
       op, iterationBounds, vectors, resultIndex, targetShape, builder)};
 }
 
-// Generates slices of 'vectorType' according to 'sizes' and 'strides', and
-// calls 'fn' with linear index and indices for each slice.
+/// Generates slices of 'vectorType' according to 'sizes' and 'strides', and
+/// calls 'fn' with linear index and indices for each slice.
 static void
-generateTransferOpSlices(VectorType vectorType, TupleType tupleType,
-                         ArrayRef<int64_t> sizes, ArrayRef<int64_t> strides,
-                         ArrayRef<Value> indices, PatternRewriter &rewriter,
+generateTransferOpSlices(Type memrefElementType, VectorType vectorType,
+                         TupleType tupleType, ArrayRef<int64_t> sizes,
+                         ArrayRef<int64_t> strides, ArrayRef<Value> indices,
+                         PatternRewriter &rewriter,
                          function_ref<void(unsigned, ArrayRef<Value>)> fn) {
   // Compute strides w.r.t. slice counts in each dimension.
   auto maybeDimSliceCounts = shapeRatio(vectorType.getShape(), sizes);
@@ -475,6 +476,25 @@ generateTransferOpSlices(VectorType vectorType, TupleType tupleType,
   int64_t numSlices = tupleType.size();
   unsigned numSliceIndices = indices.size();
+  // Compute 'indexOffset' at which to update 'indices', which is equal
+  // to the memref rank (indices.size()) minus the effective 'vectorRank'.
+  // The effective 'vectorRank' is equal to the rank of the vector type
+  // minus the rank of the memref vector element type (if it has one).
+  //
+  // For example:
+  //
+  // Given memref type 'memref<6x2x1xvector<2x4xf32>>' and vector
+  // transfer_read/write ops which read/write vectors of type
+  // 'vector<2x1x2x4xf32>': the memref rank is 3, the effective vector
+  // rank is 4 - 2 = 2, and so 'indexOffset' = 3 - 2 = 1.
+  //
+  unsigned vectorRank = vectorType.getRank();
+  if (auto memrefVectorElementType = memrefElementType.dyn_cast<VectorType>()) {
+    assert(vectorRank >= memrefVectorElementType.getRank());
+    vectorRank -= memrefVectorElementType.getRank();
+  }
+  unsigned indexOffset = numSliceIndices - vectorRank;
   auto *ctx = rewriter.getContext();
   for (unsigned i = 0; i < numSlices; ++i) {
     auto vectorOffsets = delinearize(sliceStrides, i);
@@ -482,18 +502,41 @@ generateTransferOpSlices(VectorType vectorType, TupleType tupleType,
     auto elementOffsets =
         computeElementOffsetsFromVectorSliceOffsets(sizes, vectorOffsets);
     // Compute 'sliceIndices' by adding 'sliceOffsets[i]' to 'indices[i]'.
     SmallVector<Value, 4> sliceIndices(numSliceIndices);
-    for (auto it : llvm::enumerate(indices)) {
-      auto expr = getAffineDimExpr(0, ctx) +
-                  getAffineConstantExpr(elementOffsets[it.index()], ctx);
-      auto map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr);
-      sliceIndices[it.index()] = rewriter.create<AffineApplyOp>(
-          it.value().getLoc(), map, ArrayRef<Value>(it.value()));
+    for (unsigned j = 0; j < numSliceIndices; ++j) {
+      if (j < indexOffset) {
+        sliceIndices[j] = indices[j];
+      } else {
+        auto expr = getAffineDimExpr(0, ctx) +
+                    getAffineConstantExpr(elementOffsets[j - indexOffset], ctx);
+        auto map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr);
+        sliceIndices[j] = rewriter.create<AffineApplyOp>(
+            indices[j].getLoc(), map, ArrayRef<Value>(indices[j]));
+      }
     }
     // Call 'fn' to generate slice 'i' at 'sliceIndices'.
     fn(i, sliceIndices);
   }
 }
 
+/// Returns true if 'map' is a suffix of an identity affine map, false
+/// otherwise. Example: affine_map<(d0, d1, d2, d3) -> (d2, d3)>
+static bool isIdentitySuffix(AffineMap map) {
+  if (map.getNumDims() < map.getNumResults())
+    return false;
+  ArrayRef<AffineExpr> results = map.getResults();
+  Optional<int> lastPos;
+  for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
+    auto expr = results[i].dyn_cast<AffineDimExpr>();
+    if (!expr)
+      return false;
+    int currPos = static_cast<int>(expr.getPosition());
+    if (lastPos.hasValue() && currPos != lastPos.getValue() + 1)
+      return false;
+    lastPos = currPos;
+  }
+  return true;
+}
+
 namespace {
 
 // Splits vector TransferReadOp into smaller TransferReadOps based on slicing
 // scheme of its unique ExtractSlicesOp user.
@@ -504,7 +547,7 @@ struct SplitTransferReadOp : public OpRewritePattern<vector::TransferReadOp> {
                 PatternRewriter &rewriter) const override {
     // TODO(andydavis, ntv) Support splitting TransferReadOp with non-identity
     // permutation maps. Repurpose code from MaterializeVectors transformation.
-    if (!xferReadOp.permutation_map().isIdentity())
+    if (!isIdentitySuffix(xferReadOp.permutation_map()))
       return matchFailure();
     // Return unless the unique 'xferReadOp' user is an ExtractSlicesOp.
     Value xferReadResult = xferReadOp.getResult();
@@ -523,6 +566,8 @@ struct SplitTransferReadOp : public OpRewritePattern<vector::TransferReadOp> {
     assert(llvm::all_of(strides, [](int64_t s) { return s == 1; }));
 
     Location loc = xferReadOp.getLoc();
+    auto memrefElementType =
+        xferReadOp.memref().getType().cast<MemRefType>().getElementType();
     int64_t numSlices = resultTupleType.size();
     SmallVector<Value, 4> vectorTupleValues(numSlices);
     SmallVector<Value, 4> indices(xferReadOp.indices().begin(),
@@ -535,8 +580,9 @@ struct SplitTransferReadOp : public OpRewritePattern<vector::TransferReadOp> {
           loc, sliceVectorType, xferReadOp.memref(), sliceIndices,
           xferReadOp.permutation_map(), xferReadOp.padding());
     };
-    generateTransferOpSlices(sourceVectorType, resultTupleType, sizes, strides,
-                             indices, rewriter, createSlice);
+    generateTransferOpSlices(memrefElementType, sourceVectorType,
+                             resultTupleType, sizes, strides, indices, rewriter,
+                             createSlice);
 
     // Create tuple of slice xfer read operations.
     Value tupleOp = rewriter.create<vector::TupleOp>(loc, resultTupleType,
@@ -557,7 +603,7 @@ struct SplitTransferWriteOp : public OpRewritePattern<vector::TransferWriteOp> {
                 PatternRewriter &rewriter) const override {
     // TODO(andydavis, ntv) Support splitting TransferWriteOp with non-identity
    // permutation maps. Repurpose code from MaterializeVectors transformation.
-    if (!xferWriteOp.permutation_map().isIdentity())
+    if (!isIdentitySuffix(xferWriteOp.permutation_map()))
       return matchFailure();
     // Return unless the 'xferWriteOp' 'vector' operand is an 'InsertSlicesOp'.
     auto *vectorDefOp = xferWriteOp.vector().getDefiningOp();
@@ -580,6 +626,8 @@ struct SplitTransferWriteOp : public OpRewritePattern<vector::TransferWriteOp> {
     insertSlicesOp.getStrides(strides);
 
     Location loc = xferWriteOp.getLoc();
+    auto memrefElementType =
+        xferWriteOp.memref().getType().cast<MemRefType>().getElementType();
     SmallVector<Value, 4> indices(xferWriteOp.indices().begin(),
                                   xferWriteOp.indices().end());
 
     auto createSlice = [&](unsigned index, ArrayRef<Value> sliceIndices) {
@@ -588,8 +636,9 @@ struct SplitTransferWriteOp : public OpRewritePattern<vector::TransferWriteOp> {
           loc, tupleOp.getOperand(index), xferWriteOp.memref(), sliceIndices,
           xferWriteOp.permutation_map());
     };
-    generateTransferOpSlices(resultVectorType, sourceTupleType, sizes, strides,
-                             indices, rewriter, createSlice);
+    generateTransferOpSlices(memrefElementType, resultVectorType,
+                             sourceTupleType, sizes, strides, indices, rewriter,
+                             createSlice);
 
     // Erase old 'xferWriteOp'.
     rewriter.eraseOp(xferWriteOp);
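
The relaxed permutation-map check above is what enables the memref-of-vectors case, where the map is a minor identity rather than a full identity. A few hand-picked examples of what isIdentitySuffix accepts and rejects (illustrative only, not taken from the patch):

  affine_map<(d0, d1, d2) -> (d1, d2)>   // accepted: consecutive, increasing dims
  affine_map<(d0, d1) -> (d0, d1)>       // accepted: the full identity still qualifies
  affine_map<(d0, d1, d2) -> (d2, d1)>   // rejected: dims out of order
  affine_map<(d0, d1) -> (d0 + d1)>      // rejected: result is not a bare dim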


@@ -1,6 +1,7 @@
 // RUN: mlir-opt %s -test-vector-to-vector-conversion | FileCheck %s
 
 // CHECK-DAG: #[[MAP0:map[0-9]+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-DAG: #[[MAP1:map[0-9]+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
 
 // CHECK-LABEL: func @add4x2
 // CHECK: %[[ES1:.*]] = vector.extract_slices %{{.*}}, [2, 2], [1, 1] : vector<4x2xf32> into tuple<vector<2x2xf32>, vector<2x2xf32>>
@@ -311,3 +312,37 @@ func @tuple_get(%arg0: vector<4xf32>, %arg1: vector<8xf32>) -> vector<8xf32> {
   %1 = vector.tuple_get %0, 1 : tuple<vector<4xf32>, vector<8xf32>>
   return %1 : vector<8xf32>
 }
+
+// CHECK-LABEL: func @vector_transfers_vector_element_type
+// CHECK:      %[[C0:.*]] = constant 0 : index
+// CHECK:      %[[C1:.*]] = constant 1 : index
+// CHECK:      %[[VTR0:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]], %[[C0]]], %{{.*}} {permutation_map = #[[MAP1]]} : memref<6x2x1xvector<2x4xf32>>, vector<1x1x2x4xf32>
+// CHECK-NEXT: %[[VTR1:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C1]], %[[C0]]], %{{.*}} {permutation_map = #[[MAP1]]} : memref<6x2x1xvector<2x4xf32>>, vector<1x1x2x4xf32>
+// CHECK-NEXT: vector.transfer_write %[[VTR0]], %{{.*}}[%[[C0]], %[[C0]], %[[C0]]] {permutation_map = #[[MAP1]]} : vector<1x1x2x4xf32>, memref<6x2x1xvector<2x4xf32>>
+// CHECK-NEXT: vector.transfer_write %[[VTR1]], %{{.*}}[%[[C0]], %[[C1]], %[[C0]]] {permutation_map = #[[MAP1]]} : vector<1x1x2x4xf32>, memref<6x2x1xvector<2x4xf32>>
+
+func @vector_transfers_vector_element_type() {
+  %c0 = constant 0 : index
+  %cf0 = constant 0.000000e+00 : f32
+  %vf0 = splat %cf0 : vector<2x4xf32>
+
+  %0 = alloc() : memref<6x2x1xvector<2x4xf32>>
+
+  %1 = vector.transfer_read %0[%c0, %c0, %c0], %vf0
+    {permutation_map = affine_map<(d0, d1, d2) -> (d1, d2)>}
+    : memref<6x2x1xvector<2x4xf32>>, vector<2x1x2x4xf32>
+
+  %2 = vector.extract_slices %1, [1, 1, 2, 4], [1, 1, 1, 1]
+    : vector<2x1x2x4xf32> into tuple<vector<1x1x2x4xf32>, vector<1x1x2x4xf32>>
+  %3 = vector.tuple_get %2, 0 : tuple<vector<1x1x2x4xf32>, vector<1x1x2x4xf32>>
+  %4 = vector.tuple_get %2, 1 : tuple<vector<1x1x2x4xf32>, vector<1x1x2x4xf32>>
+  %5 = vector.tuple %3, %4 : vector<1x1x2x4xf32>, vector<1x1x2x4xf32>
+  %6 = vector.insert_slices %5, [1, 1, 2, 4], [1, 1, 1, 1]
+    : tuple<vector<1x1x2x4xf32>, vector<1x1x2x4xf32>> into vector<2x1x2x4xf32>
+
+  vector.transfer_write %6, %0[%c0, %c0, %c0]
+    {permutation_map = affine_map<(d0, d1, d2) -> (d1, d2)>}
+    : vector<2x1x2x4xf32>, memref<6x2x1xvector<2x4xf32>>
+  return
+}