diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index 997b56a1ce14..c131fde517f8 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -505,25 +505,61 @@ static Value collapseInnerDims(PatternRewriter &rewriter, mlir::Location loc,
   return rewriter.create<memref::CollapseShapeOp>(loc, input, reassociation);
 }
 
-/// Checks that the indices corresponding to dimensions starting at
-/// `firstDimToCollapse` are constant 0, and writes to `outIndices`
-/// the truncated indices where `firstDimToCollapse` is now the innermost dim.
-/// TODO: Extract the logic that writes to outIndices so that this method
-/// simply checks one pre-condition.
-static LogicalResult
-checkAndCollapseInnerZeroIndices(ValueRange indices, int64_t firstDimToCollapse,
-                                 SmallVector<Value> &outIndices) {
-  int64_t rank = indices.size();
-  if (firstDimToCollapse >= rank)
-    return failure();
-  for (int64_t i = firstDimToCollapse; i < rank; ++i) {
-    std::optional<int64_t> cst = getConstantIntValue(indices[i]);
-    if (!cst || cst.value() != 0)
-      return failure();
+/// Returns the new indices that collapses the inner dimensions starting from
+/// the `firstDimToCollapse` dimension.
+static SmallVector<Value> getCollapsedIndices(RewriterBase &rewriter,
+                                              Location loc,
+                                              ArrayRef<int64_t> shape,
+                                              ValueRange indices,
+                                              int64_t firstDimToCollapse) {
+  assert(firstDimToCollapse < static_cast<int64_t>(indices.size()));
+
+  // If all the collapsed indices are zero then no extra logic is needed.
+  // Otherwise, a new offset/index has to be computed.
+  SmallVector<Value> indicesAfterCollapsing(
+      indices.begin(), indices.begin() + firstDimToCollapse);
+  SmallVector<Value> indicesToCollapse(indices.begin() + firstDimToCollapse,
+                                       indices.end());
+  if (llvm::all_of(indicesToCollapse, isZeroIndex)) {
+    indicesAfterCollapsing.push_back(indicesToCollapse[0]);
+    return indicesAfterCollapsing;
   }
-  outIndices = indices;
-  outIndices.resize(firstDimToCollapse + 1);
-  return success();
+
+  // Compute the remaining trailing index/offset required for reading from
+  // the collapsed memref:
+  //
+  //    offset = 0
+  //    for (i = firstDimToCollapse; i < outputRank; ++i)
+  //      offset += sourceType.getDimSize(i) * transferReadOp.indices[i]
+  //
+  // For this example:
+  //   %2 = vector.transfer_read/write %arg4[%c0, %arg0, %c0] (...) :
+  //      memref<1x43x2xi32>, vector<1x2xi32>
+  // which would be collapsed to:
+  //   %1 = vector.transfer_read/write %collapse_shape[%c0, %offset] (...) :
+  //      memref<1x86xi32>, vector<2xi32>
+  // one would get the following offset:
+  //    %offset = %arg0 * 43
+  OpFoldResult collapsedOffset =
+      rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult();
+
+  auto collapsedStrides = computeSuffixProduct(
+      ArrayRef<int64_t>(shape.begin() + firstDimToCollapse, shape.end()));
+
+  // Compute the collapsed offset.
+  auto &&[collapsedExpr, collapsedVals] =
+      computeLinearIndex(collapsedOffset, collapsedStrides, indicesToCollapse);
+  collapsedOffset = affine::makeComposedFoldedAffineApply(
+      rewriter, loc, collapsedExpr, collapsedVals);
+
+  if (collapsedOffset.is<Value>()) {
+    indicesAfterCollapsing.push_back(collapsedOffset.get<Value>());
+  } else {
+    indicesAfterCollapsing.push_back(rewriter.create<arith::ConstantIndexOp>(
+        loc, *getConstantIntValue(collapsedOffset)));
+  }
+
+  return indicesAfterCollapsing;
 }
 
 namespace {
@@ -594,54 +630,9 @@ public:
         AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());
 
     // 2.2 New indices
-    // If all the collapsed indices are zero then no extra logic is needed.
-    // Otherwise, a new offset/index has to be computed.
-    SmallVector<Value> collapsedIndices;
-    if (failed(checkAndCollapseInnerZeroIndices(transferReadOp.getIndices(),
-                                                firstDimToCollapse,
-                                                collapsedIndices))) {
-      // Copy all the leading indices.
-      SmallVector<Value> indices = transferReadOp.getIndices();
-      collapsedIndices.append(indices.begin(),
-                              indices.begin() + firstDimToCollapse);
-
-      // Compute the remaining trailing index/offset required for reading from
-      // the collapsed memref:
-      //
-      //    offset = 0
-      //    for (i = firstDimToCollapse; i < outputRank; ++i)
-      //      offset += sourceType.getDimSize(i) * transferReadOp.indices[i]
-      //
-      // For this example:
-      //   %2 = vector.transfer_read %arg4[%c0, %arg0, %c0] (...) :
-      //      memref<1x43x2xi32>, vector<1x2xi32>
-      // which would be collapsed to:
-      //   %1 = vector.transfer_read %collapse_shape[%c0, %offset] (...) :
-      //      memref<1x86xi32>, vector<2xi32>
-      // one would get the following offset:
-      //    %offset = %arg0 * 43
-      OpFoldResult collapsedOffset =
-          rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult();
-
-      auto sourceShape = sourceType.getShape();
-      auto collapsedStrides = computeSuffixProduct(ArrayRef<int64_t>(
-          sourceShape.begin() + firstDimToCollapse, sourceShape.end()));
-
-      // Compute the collapsed offset.
-      ArrayRef<Value> indicesToCollapse(indices.begin() + firstDimToCollapse,
-                                        indices.end());
-      auto &&[collapsedExpr, collapsedVals] = computeLinearIndex(
-          collapsedOffset, collapsedStrides, indicesToCollapse);
-      collapsedOffset = affine::makeComposedFoldedAffineApply(
-          rewriter, loc, collapsedExpr, collapsedVals);
-
-      if (collapsedOffset.is<Value>()) {
-        collapsedIndices.push_back(collapsedOffset.get<Value>());
-      } else {
-        collapsedIndices.push_back(rewriter.create<arith::ConstantIndexOp>(
-            loc, *getConstantIntValue(collapsedOffset)));
-      }
-    }
+    SmallVector<Value> collapsedIndices =
+        getCollapsedIndices(rewriter, loc, sourceType.getShape(),
+                            transferReadOp.getIndices(), firstDimToCollapse);
 
     // 3. Create new vector.transfer_read that reads from the collapsed memref
     VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
@@ -697,8 +688,7 @@ public:
       return failure();
    if (!vector::isContiguousSlice(sourceType, vectorType))
       return failure();
-    int64_t firstContiguousInnerDim =
-        sourceType.getRank() - vectorType.getRank();
+    int64_t firstDimToCollapse = sourceType.getRank() - vectorType.getRank();
     // TODO: generalize this pattern, relax the requirements here.
     if (transferWriteOp.hasOutOfBoundsDim())
       return failure();
@@ -706,22 +696,23 @@ public:
       return failure();
     if (transferWriteOp.getMask())
       return failure();
-    SmallVector<Value> collapsedIndices;
-    if (failed(checkAndCollapseInnerZeroIndices(transferWriteOp.getIndices(),
-                                                firstContiguousInnerDim,
-                                                collapsedIndices)))
-      return failure();
+
+    SmallVector<Value> collapsedIndices =
+        getCollapsedIndices(rewriter, loc, sourceType.getShape(),
+                            transferWriteOp.getIndices(), firstDimToCollapse);
 
     Value collapsedSource =
-        collapseInnerDims(rewriter, loc, source, firstContiguousInnerDim);
+        collapseInnerDims(rewriter, loc, source, firstDimToCollapse);
     MemRefType collapsedSourceType =
         cast<MemRefType>(collapsedSource.getType());
     int64_t collapsedRank = collapsedSourceType.getRank();
-    assert(collapsedRank == firstContiguousInnerDim + 1);
+    assert(collapsedRank == firstDimToCollapse + 1);
+
     SmallVector<AffineExpr, 1> dimExprs{
-        getAffineDimExpr(firstContiguousInnerDim, rewriter.getContext())};
+        getAffineDimExpr(firstDimToCollapse, rewriter.getContext())};
     auto collapsedMap =
         AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());
+
     VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
                                                  vectorType.getElementType());
     Value flatVector =
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index 788ae9ac044e..65bf0b9335d2 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -471,16 +471,16 @@ func.func @regression_non_contiguous_dim_read(%subview : memref<1x3x3x2xf32, str
 }
 
 // CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 2)>
-// CHECK-LABEL: func.func @regression_non_contiguous_dim_read(
-// CHECK: %[[COLLAPSE:.+]] = memref.collapse_shape %{{.*}} {{\[}}[0], [1], [2, 3]] : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>> into memref<1x3x6xf32, strided<[40, 10, 1], offset: ?>>
-// CHECK: %[[APPLY:.*]] = affine.apply #[[$MAP]]()
+// CHECK-LABEL:   func.func @regression_non_contiguous_dim_read(
+// CHECK:   %[[COLLAPSE:.+]] = memref.collapse_shape %{{.*}} {{\[}}[0], [1], [2, 3]] : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>> into memref<1x3x6xf32, strided<[40, 10, 1], offset: ?>>
+// CHECK:   %[[APPLY:.*]] = affine.apply #[[$MAP]]()
 
 // CHECK-128B-LABEL: func @regression_non_contiguous_dim_read(
 // CHECK-128B: memref.collapse_shape
 
 // -----
 
-func.func @unsupported_non_contiguous_dim_write(%value : vector<2x2xf32>,
+func.func @regression_non_contiguous_dim_write(%value : vector<2x2xf32>,
   %subview : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>,
   %idx0 : index, %idx1 : index) {
   %c0 = arith.constant 0 : index
@@ -488,8 +488,35 @@ func.func @unsupported_non_contiguous_dim_write(%value : vector<2x2xf32>,
   return
 }
 
-// CHECK-LABEL: func.func @unsupported_non_contiguous_dim_write(
-// CHECK-NOT: memref.collapse_shape
+// CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 2)>
+// CHECK-LABEL:   func.func @regression_non_contiguous_dim_write(
+// CHECK:   %[[APPLY:.*]] = affine.apply #[[$MAP]]()
+// CHECK:   %[[COLLAPSE:.+]] = memref.collapse_shape %{{.*}} {{\[}}[0], [1], [2, 3]] : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>> into memref<1x3x6xf32, strided<[40, 10, 1], offset: ?>>
 
-// CHECK-128B-LABEL: func @unsupported_non_contiguous_dim_write(
-// CHECK-128B-NOT: memref.collapse_shape
+// CHECK-128B-LABEL: func @regression_non_contiguous_dim_write(
+// CHECK-128B: memref.collapse_shape
+
+// -----
+
+func.func @negative_out_of_bound_transfer_read(
+    %arg : memref<?x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<5x4x3x2xi8> {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0 : i8
+  %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst {in_bounds = [false, true, true, true]} :
+    memref<?x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<5x4x3x2xi8>
+  return %v : vector<5x4x3x2xi8>
+}
+// CHECK: func.func @negative_out_of_bound_transfer_read
+// CHECK-NOT: memref.collapse_shape
+
+// -----
+
+func.func @negative_out_of_bound_transfer_write(
+    %arg : memref<?x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, %vec : vector<1x1x3x2xi8>) {
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] {in_bounds = [false, true, true, true]} :
+    vector<1x1x3x2xi8>, memref<?x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>
+  return
+}
+// CHECK: func.func @negative_out_of_bound_transfer_write
+// CHECK-NOT: memref.collapse_shape
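
Note (illustrative sketch, not part of the patch): the rewrite that the new getCollapsedIndices helper enables for the write pattern can be sketched on a small example in the style of the tests above. The function name @sketch_write_non_zero_idx is made up, and the exact SSA names and operation order produced by -test-vector-transfer-flatten-patterns may differ. Input, with a non-zero index %idx on a collapsed dimension, which the old checkAndCollapseInnerZeroIndices path rejected for writes (all dims must be declared in-bounds for the pattern to fire, see the negative tests above):

// Hypothetical input: 2-D write into a 3-D contiguous memref, %idx != 0.
func.func @sketch_write_non_zero_idx(%vec : vector<1x2xi32>,
    %mem : memref<1x43x2xi32>, %idx : index) {
  %c0 = arith.constant 0 : index
  vector.transfer_write %vec, %mem[%c0, %idx, %c0] {in_bounds = [true, true]} :
    vector<1x2xi32>, memref<1x43x2xi32>
  return
}

Expected shape of the output: dims 1 and 2 (43x2) are collapsed into 86, the trailing index is linearized with the suffix-product strides [2, 1] of the collapsed dims (so the offset is %idx * 2, matching the (s0 * 2) map in the regression tests above), and the vector is flattened with a shape_cast:

//   %collapsed = memref.collapse_shape %mem [[0], [1, 2]]
//       : memref<1x43x2xi32> into memref<1x86xi32>
//   %offset = affine.apply affine_map<()[s0] -> (s0 * 2)>()[%idx]
//   %flat = vector.shape_cast %vec : vector<1x2xi32> to vector<2xi32>
//   vector.transfer_write %flat, %collapsed[%c0, %offset] {in_bounds = [true]}
//       : vector<2xi32>, memref<1x86xi32>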