From f9efce1dd5ceef7ed594f42d207b13bb6c9f1b6e Mon Sep 17 00:00:00 2001
From: Andy Davis
Date: Wed, 5 Feb 2020 16:06:40 -0800
Subject: [PATCH] [mlir][VectorOps] Support vector transfer_read/write
 unrolling for memrefs with vector element type.

Summary:
[mlir][VectorOps] Support vector transfer_read/write unrolling for
memrefs with vector element type.

When unrolling vector transfer read/write operations on memrefs with
vector element type, the indices used to index the memref argument must
be updated to reflect the unrolled operation. However, in the case of
memrefs with vector element type, we need to be careful to update only
the relevant memref indices.

For example, a vector transfer read with the following source/result
types,

  memref<6x2x1xvector<2x4xf32>>, vector<2x1x2x4xf32>,

should update only memref indices 1 and 2 during unrolling.

Reviewers: nicolasvasilache, aartbik

Reviewed By: nicolasvasilache, aartbik

Subscribers: lebedev.ri, Joonsoo, merge_guards_bot, mehdi_amini, rriddle,
jpienaar, burmako, shauheen, antiagainst, arpith-jacob, mgester, lucyrfox,
liufengdb, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D72965
---
 .../Dialect/VectorOps/VectorTransforms.cpp   | 83 +++++++++++++++----
 .../Dialect/VectorOps/vector-transforms.mlir | 35 ++++++++
 2 files changed, 101 insertions(+), 17 deletions(-)

diff --git a/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp b/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
index 670a0d0fa9bf..440c7707ce75 100644
--- a/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
@@ -460,12 +460,13 @@ SmallVector<Value, 1> mlir::vector::unrollSingleResultOpMatchingType(
       op, iterationBounds, vectors, resultIndex, targetShape, builder)};
 }
 
-// Generates slices of 'vectorType' according to 'sizes' and 'strides, and
-// calls 'fn' with linear index and indices for each slice.
+/// Generates slices of 'vectorType' according to 'sizes' and 'strides', and
+/// calls 'fn' with the linear index and indices for each slice.
 static void
-generateTransferOpSlices(VectorType vectorType, TupleType tupleType,
-                         ArrayRef<int64_t> sizes, ArrayRef<int64_t> strides,
-                         ArrayRef<Value> indices, PatternRewriter &rewriter,
+generateTransferOpSlices(Type memrefElementType, VectorType vectorType,
+                         TupleType tupleType, ArrayRef<int64_t> sizes,
+                         ArrayRef<int64_t> strides, ArrayRef<Value> indices,
+                         PatternRewriter &rewriter,
                          function_ref<void(unsigned, ArrayRef<Value>)> fn) {
   // Compute strides w.r.t. to slice counts in each dimension.
   auto maybeDimSliceCounts = shapeRatio(vectorType.getShape(), sizes);
@@ -475,6 +476,25 @@ generateTransferOpSlices(VectorType vectorType, TupleType tupleType,
   int64_t numSlices = tupleType.size();
 
   unsigned numSliceIndices = indices.size();
+  // Compute 'indexOffset' at which to update 'indices', which is equal
+  // to the memref rank (indices.size) minus the effective 'vectorRank'.
+  // The effective 'vectorRank' is equal to the rank of the vector type
+  // minus the rank of the memref vector element type (if it has one).
+  //
+  // For example:
+  //
+  //   Given memref type 'memref<6x2x1xvector<2x4xf32>>' and vector
+  //   transfer_read/write ops which read/write vectors of type
+  //   'vector<2x1x2x4xf32>': the memref rank is 3, the effective
+  //   vector rank is 4 - 2 = 2, and so 'indexOffset' = 3 - 2 = 1.
+  //
+  unsigned vectorRank = vectorType.getRank();
+  if (auto memrefVectorElementType = memrefElementType.dyn_cast<VectorType>()) {
+    assert(vectorRank >= memrefVectorElementType.getRank());
+    vectorRank -= memrefVectorElementType.getRank();
+  }
+  unsigned indexOffset = numSliceIndices - vectorRank;
+
   auto *ctx = rewriter.getContext();
   for (unsigned i = 0; i < numSlices; ++i) {
     auto vectorOffsets = delinearize(sliceStrides, i);
@@ -482,18 +502,41 @@ generateTransferOpSlices(VectorType vectorType, TupleType tupleType,
         computeElementOffsetsFromVectorSliceOffsets(sizes, vectorOffsets);
     // Compute 'sliceIndices' by adding 'sliceOffsets[i]' to 'indices[i]'.
     SmallVector<Value, 4> sliceIndices(numSliceIndices);
-    for (auto it : llvm::enumerate(indices)) {
-      auto expr = getAffineDimExpr(0, ctx) +
-                  getAffineConstantExpr(elementOffsets[it.index()], ctx);
-      auto map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr);
-      sliceIndices[it.index()] = rewriter.create<AffineApplyOp>(
-          it.value().getLoc(), map, ArrayRef<Value>(it.value()));
+    for (unsigned j = 0; j < numSliceIndices; ++j) {
+      if (j < indexOffset) {
+        sliceIndices[j] = indices[j];
+      } else {
+        auto expr = getAffineDimExpr(0, ctx) +
+                    getAffineConstantExpr(elementOffsets[j - indexOffset], ctx);
+        auto map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr);
+        sliceIndices[j] = rewriter.create<AffineApplyOp>(
+            indices[j].getLoc(), map, ArrayRef<Value>(indices[j]));
+      }
     }
     // Call 'fn' to generate slice 'i' at 'sliceIndices'.
     fn(i, sliceIndices);
   }
 }
 
+/// Returns true if 'map' is a suffix of an identity affine map, false
+/// otherwise. Example: affine_map<(d0, d1, d2, d3) -> (d2, d3)>
+static bool isIdentitySuffix(AffineMap map) {
+  if (map.getNumDims() < map.getNumResults())
+    return false;
+  ArrayRef<AffineExpr> results = map.getResults();
+  Optional<int> lastPos;
+  for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
+    auto expr = results[i].dyn_cast<AffineDimExpr>();
+    if (!expr)
+      return false;
+    int currPos = static_cast<int>(expr.getPosition());
+    if (lastPos.hasValue() && currPos != lastPos.getValue() + 1)
+      return false;
+    lastPos = currPos;
+  }
+  return true;
+}
+
 namespace {
 
 // Splits vector TransferReadOp into smaller TransferReadOps based on slicing
 // scheme of its unique ExtractSlicesOp user.
@@ -504,7 +547,7 @@ struct SplitTransferReadOp : public OpRewritePattern<vector::TransferReadOp> {
                                      PatternRewriter &rewriter) const override {
     // TODO(andydavis, ntv) Support splitting TransferReadOp with non-identity
     // permutation maps. Repurpose code from MaterializeVectors transformation.
-    if (!xferReadOp.permutation_map().isIdentity())
+    if (!isIdentitySuffix(xferReadOp.permutation_map()))
       return matchFailure();
     // Return unless the unique 'xferReadOp' user is an ExtractSlicesOp.
     Value xferReadResult = xferReadOp.getResult();
@@ -523,6 +566,8 @@ struct SplitTransferReadOp : public OpRewritePattern<vector::TransferReadOp> {
     assert(llvm::all_of(strides, [](int64_t s) { return s == 1; }));
 
     Location loc = xferReadOp.getLoc();
+    auto memrefElementType =
+        xferReadOp.memref().getType().cast<MemRefType>().getElementType();
     int64_t numSlices = resultTupleType.size();
     SmallVector<Value, 4> vectorTupleValues(numSlices);
     SmallVector<Value, 4> indices(xferReadOp.indices().begin(),
@@ -535,8 +580,9 @@ struct SplitTransferReadOp : public OpRewritePattern<vector::TransferReadOp> {
           loc, sliceVectorType, xferReadOp.memref(), sliceIndices,
           xferReadOp.permutation_map(), xferReadOp.padding());
     };
-    generateTransferOpSlices(sourceVectorType, resultTupleType, sizes, strides,
-                             indices, rewriter, createSlice);
+    generateTransferOpSlices(memrefElementType, sourceVectorType,
+                             resultTupleType, sizes, strides, indices, rewriter,
+                             createSlice);
 
     // Create tuple of splice xfer read operations.
     Value tupleOp = rewriter.create<vector::TupleOp>(loc, resultTupleType,
@@ -557,7 +603,7 @@ struct SplitTransferWriteOp : public OpRewritePattern<vector::TransferWriteOp> {
                                      PatternRewriter &rewriter) const override {
     // TODO(andydavis, ntv) Support splitting TransferWriteOp with non-identity
     // permutation maps. Repurpose code from MaterializeVectors transformation.
-    if (!xferWriteOp.permutation_map().isIdentity())
+    if (!isIdentitySuffix(xferWriteOp.permutation_map()))
       return matchFailure();
     // Return unless the 'xferWriteOp' 'vector' operand is an 'InsertSlicesOp'.
     auto *vectorDefOp = xferWriteOp.vector().getDefiningOp();
@@ -580,6 +626,8 @@ struct SplitTransferWriteOp : public OpRewritePattern<vector::TransferWriteOp> {
     insertSlicesOp.getStrides(strides);
 
     Location loc = xferWriteOp.getLoc();
+    auto memrefElementType =
+        xferWriteOp.memref().getType().cast<MemRefType>().getElementType();
     SmallVector<Value, 4> indices(xferWriteOp.indices().begin(),
                                   xferWriteOp.indices().end());
     auto createSlice = [&](unsigned index, ArrayRef<Value> sliceIndices) {
       rewriter.create<vector::TransferWriteOp>(
           loc, tupleOp.getOperand(index), xferWriteOp.memref(), sliceIndices,
           xferWriteOp.permutation_map());
     };
-    generateTransferOpSlices(resultVectorType, sourceTupleType, sizes, strides,
-                             indices, rewriter, createSlice);
+    generateTransferOpSlices(memrefElementType, resultVectorType,
+                             sourceTupleType, sizes, strides, indices, rewriter,
+                             createSlice);
 
     // Erase old 'xferWriteOp'.
     rewriter.eraseOp(xferWriteOp);
diff --git a/mlir/test/Dialect/VectorOps/vector-transforms.mlir b/mlir/test/Dialect/VectorOps/vector-transforms.mlir
index 49a462424432..1153ffb2999f 100644
--- a/mlir/test/Dialect/VectorOps/vector-transforms.mlir
+++ b/mlir/test/Dialect/VectorOps/vector-transforms.mlir
@@ -1,6 +1,7 @@
 // RUN: mlir-opt %s -test-vector-to-vector-conversion | FileCheck %s
 
 // CHECK-DAG: #[[MAP0:map[0-9]+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-DAG: #[[MAP1:map[0-9]+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
 
 // CHECK-LABEL: func @add4x2
 //      CHECK: %[[ES1:.*]] = vector.extract_slices %{{.*}}, [2, 2], [1, 1] : vector<4x2xf32> into tuple<vector<2x2xf32>, vector<2x2xf32>>
@@ -311,3 +312,37 @@ func @tuple_get(%arg0: vector<4xf32>, %arg1: vector<8xf32>) -> vector<8xf32> {
   %1 = vector.tuple_get %0, 1 : tuple<vector<4xf32>, vector<8xf32>>
   return %1 : vector<8xf32>
 }
+
+// CHECK-LABEL: func @vector_transfers_vector_element_type
+//      CHECK: %[[C0:.*]] = constant 0 : index
+//      CHECK: %[[C1:.*]] = constant 1 : index
+//      CHECK: %[[VTR0:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]], %[[C0]]], %{{.*}} {permutation_map = #[[MAP1]]} : memref<6x2x1xvector<2x4xf32>>, vector<1x1x2x4xf32>
+// CHECK-NEXT: %[[VTR1:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C1]], %[[C0]]], %{{.*}} {permutation_map = #[[MAP1]]} : memref<6x2x1xvector<2x4xf32>>, vector<1x1x2x4xf32>
+// CHECK-NEXT: vector.transfer_write %[[VTR0]], %{{.*}}[%[[C0]], %[[C0]], %[[C0]]] {permutation_map = #[[MAP1]]} : vector<1x1x2x4xf32>, memref<6x2x1xvector<2x4xf32>>
+// CHECK-NEXT: vector.transfer_write %[[VTR1]], %{{.*}}[%[[C0]], %[[C1]], %[[C0]]] {permutation_map = #[[MAP1]]} : vector<1x1x2x4xf32>, memref<6x2x1xvector<2x4xf32>>
+
+func @vector_transfers_vector_element_type() {
+  %c0 = constant 0 : index
+  %cf0 = constant 0.000000e+00 : f32
+  %vf0 = splat %cf0 : vector<2x4xf32>
+
+  %0 = alloc() : memref<6x2x1xvector<2x4xf32>>
+
+  %1 = vector.transfer_read %0[%c0, %c0, %c0], %vf0
+      {permutation_map = affine_map<(d0, d1, d2) -> (d1, d2)>}
+        : memref<6x2x1xvector<2x4xf32>>, vector<2x1x2x4xf32>
+
+  %2 = vector.extract_slices %1, [1, 1, 2, 4], [1, 1, 1, 1]
+    : vector<2x1x2x4xf32> into tuple<vector<1x1x2x4xf32>, vector<1x1x2x4xf32>>
+  %3 = vector.tuple_get %2, 0 : tuple<vector<1x1x2x4xf32>, vector<1x1x2x4xf32>>
+  %4 = vector.tuple_get %2, 1 : tuple<vector<1x1x2x4xf32>, vector<1x1x2x4xf32>>
+  %5 = vector.tuple %3, %4 : vector<1x1x2x4xf32>, vector<1x1x2x4xf32>
+  %6 = vector.insert_slices %5, [1, 1, 2, 4], [1, 1, 1, 1]
+    : tuple<vector<1x1x2x4xf32>, vector<1x1x2x4xf32>> into vector<2x1x2x4xf32>
+
+  vector.transfer_write %6, %0[%c0, %c0, %c0]
+      {permutation_map = affine_map<(d0, d1, d2) -> (d1, d2)>}
+        : vector<2x1x2x4xf32>, memref<6x2x1xvector<2x4xf32>>
+
+  return
+}
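
As an illustration of the logic above (not part of the patch; every name
below is a hypothetical stand-in for the MLIR types and helpers), here is a
minimal standalone C++ sketch of the two checks the change introduces: the
'indexOffset' computation that leaves leading memref indices untouched, and
the consecutive-dimension check behind isIdentitySuffix. It replays the
example from the summary:

// Standalone sketch; plain integers replace the MLIR types (MemRefType,
// VectorType, AffineMap) that the real code operates on.
#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

// Mirrors the 'indexOffset' computation in generateTransferOpSlices: the
// offset is the memref rank minus the effective vector rank, where the
// effective rank subtracts the rank of the memref's vector element type.
static unsigned computeIndexOffset(unsigned memrefRank, unsigned vectorRank,
                                   unsigned elementVectorRank) {
  assert(vectorRank >= elementVectorRank);
  unsigned effectiveVectorRank = vectorRank - elementVectorRank;
  assert(memrefRank >= effectiveVectorRank);
  return memrefRank - effectiveVectorRank;
}

// Mirrors isIdentitySuffix on extracted dim positions: every result must be
// a dimension, and positions must be consecutive.
static bool isConsecutive(const std::vector<int> &positions, int numDims) {
  if (numDims < static_cast<int>(positions.size()))
    return false;
  for (size_t i = 1; i < positions.size(); ++i)
    if (positions[i] != positions[i - 1] + 1)
      return false;
  return true;
}

int main() {
  // The example from the summary: memref<6x2x1xvector<2x4xf32>> read as
  // vector<2x1x2x4xf32>. Memref rank 3, vector rank 4, element vector rank 2,
  // so only memref indices at positions >= 1 (indices 1 and 2) are updated.
  unsigned indexOffset = computeIndexOffset(/*memrefRank=*/3, /*vectorRank=*/4,
                                            /*elementVectorRank=*/2);
  std::cout << "indexOffset = " << indexOffset << "\n"; // prints 1

  // Applying one slice's element offsets to the memref indices: leading
  // indices pass through, trailing indices get the per-slice offsets.
  std::vector<int64_t> indices = {5, 0, 0};     // original memref indices
  std::vector<int64_t> elementOffsets = {1, 0}; // slice offset along dim 1
  for (size_t j = indexOffset; j < indices.size(); ++j)
    indices[j] += elementOffsets[j - indexOffset];
  std::cout << indices[0] << " " << indices[1] << " " << indices[2] << "\n";
  // prints "5 1 0"

  // The test's map affine_map<(d0, d1, d2) -> (d1, d2)> is an identity
  // suffix (positions 1, 2 are consecutive); (d0, d2) would not be.
  assert(isConsecutive({1, 2}, 3));
  assert(!isConsecutive({0, 2}, 3));
  return 0;
}

In the patch itself the pass-through is the 'j < indexOffset' branch of the
rewritten loop, and the offsets are applied through affine.apply ops rather
than plain integer addition.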