Mirror of https://github.com/intel/llvm.git
[mlir][vector] Improve flattening vector.transfer_write ops. (#94051)
We can flatten transfer ops even when the collapsed indices are not all zero, because the new trailing offset can be computed from the collapsed strides. This is already supported for vector.transfer_read; this revision extracts that logic into a shared getCollapsedIndices helper and reuses it for vector.transfer_write.
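To make the change concrete, here is a minimal sketch of the write case this enables, modeled on the renamed @regression_non_contiguous_dim_write test updated below. The transfer_write body and the exact flattened output are assumptions inferred from the updated CHECK lines; the function name @write_with_non_zero_inner_index is illustrative only.

// Input (sketch): the trailing indices being collapsed (%idx1, %c0) are not
// all constant zero, so this write was previously left un-flattened.
func.func @write_with_non_zero_inner_index(%value : vector<2x2xf32>,
    %subview : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>,
    %idx0 : index, %idx1 : index) {
  %c0 = arith.constant 0 : index
  vector.transfer_write %value, %subview[%c0, %idx0, %idx1, %c0] {in_bounds = [true, true]}
    : vector<2x2xf32>, memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>
  return
}

// After flattening (shape of the output inferred from the CHECK lines): only
// the contiguous inner dims [2, 3] are collapsed, and the new trailing index
// is linearized from the collapsed strides, here %idx1 * 2.
//   %collapsed = memref.collapse_shape %subview [[0], [1], [2, 3]]
//       : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>
//         into memref<1x3x6xf32, strided<[40, 10, 1], offset: ?>>
//   %offset = affine.apply affine_map<()[s0] -> (s0 * 2)>()[%idx1]
//   %flat = vector.shape_cast %value : vector<2x2xf32> to vector<4xf32>
//   vector.transfer_write %flat, %collapsed[%c0, %idx0, %offset] ...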
@@ -505,25 +505,61 @@ static Value collapseInnerDims(PatternRewriter &rewriter, mlir::Location loc,
   return rewriter.create<memref::CollapseShapeOp>(loc, input, reassociation);
 }
 
-/// Checks that the indices corresponding to dimensions starting at
-/// `firstDimToCollapse` are constant 0, and writes to `outIndices`
-/// the truncated indices where `firstDimToCollapse` is now the innermost dim.
-/// TODO: Extract the logic that writes to outIndices so that this method
-/// simply checks one pre-condition.
-static LogicalResult
-checkAndCollapseInnerZeroIndices(ValueRange indices, int64_t firstDimToCollapse,
-                                 SmallVector<Value> &outIndices) {
-  int64_t rank = indices.size();
-  if (firstDimToCollapse >= rank)
-    return failure();
-  for (int64_t i = firstDimToCollapse; i < rank; ++i) {
-    std::optional<int64_t> cst = getConstantIntValue(indices[i]);
-    if (!cst || cst.value() != 0)
-      return failure();
-  }
-  outIndices = indices;
-  outIndices.resize(firstDimToCollapse + 1);
-  return success();
-}
+/// Returns the new indices that collapses the inner dimensions starting from
+/// the `firstDimToCollapse` dimension.
+static SmallVector<Value> getCollapsedIndices(RewriterBase &rewriter,
+                                              Location loc,
+                                              ArrayRef<int64_t> shape,
+                                              ValueRange indices,
+                                              int64_t firstDimToCollapse) {
+  assert(firstDimToCollapse < static_cast<int64_t>(indices.size()));
+
+  // If all the collapsed indices are zero then no extra logic is needed.
+  // Otherwise, a new offset/index has to be computed.
+  SmallVector<Value> indicesAfterCollapsing(
+      indices.begin(), indices.begin() + firstDimToCollapse);
+  SmallVector<Value> indicesToCollapse(indices.begin() + firstDimToCollapse,
+                                       indices.end());
+  if (llvm::all_of(indicesToCollapse, isZeroIndex)) {
+    indicesAfterCollapsing.push_back(indicesToCollapse[0]);
+    return indicesAfterCollapsing;
+  }
+
+  // Compute the remaining trailing index/offset required for reading from
+  // the collapsed memref:
+  //
+  //    offset = 0
+  //    for (i = firstDimToCollapse; i < outputRank; ++i)
+  //      offset += sourceType.getDimSize(i) * transferReadOp.indices[i]
+  //
+  // For this example:
+  //   %2 = vector.transfer_read/write %arg4[%c0, %arg0, %c0] (...) :
+  //     memref<1x43x2xi32>, vector<1x2xi32>
+  // which would be collapsed to:
+  //   %1 = vector.transfer_read/write %collapse_shape[%c0, %offset] (...) :
+  //     memref<1x86xi32>, vector<2xi32>
+  // one would get the following offset:
+  //   %offset = %arg0 * 43
+  OpFoldResult collapsedOffset =
+      rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult();
+
+  auto collapsedStrides = computeSuffixProduct(
+      ArrayRef<int64_t>(shape.begin() + firstDimToCollapse, shape.end()));
+
+  // Compute the collapsed offset.
+  auto &&[collapsedExpr, collapsedVals] =
+      computeLinearIndex(collapsedOffset, collapsedStrides, indicesToCollapse);
+  collapsedOffset = affine::makeComposedFoldedAffineApply(
+      rewriter, loc, collapsedExpr, collapsedVals);
+
+  if (collapsedOffset.is<Value>()) {
+    indicesAfterCollapsing.push_back(collapsedOffset.get<Value>());
+  } else {
+    indicesAfterCollapsing.push_back(rewriter.create<arith::ConstantIndexOp>(
+        loc, *getConstantIntValue(collapsedOffset)));
+  }
+
+  return indicesAfterCollapsing;
+}
 
 namespace {
@@ -594,54 +630,9 @@ public:
         AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());
 
     // 2.2 New indices
-    // If all the collapsed indices are zero then no extra logic is needed.
-    // Otherwise, a new offset/index has to be computed.
-    SmallVector<Value> collapsedIndices;
-    if (failed(checkAndCollapseInnerZeroIndices(transferReadOp.getIndices(),
-                                                firstDimToCollapse,
-                                                collapsedIndices))) {
-      // Copy all the leading indices.
-      SmallVector<Value> indices = transferReadOp.getIndices();
-      collapsedIndices.append(indices.begin(),
-                              indices.begin() + firstDimToCollapse);
-
-      // Compute the remaining trailing index/offset required for reading from
-      // the collapsed memref:
-      //
-      //    offset = 0
-      //    for (i = firstDimToCollapse; i < outputRank; ++i)
-      //      offset += sourceType.getDimSize(i) * transferReadOp.indices[i]
-      //
-      // For this example:
-      //   %2 = vector.transfer_read %arg4[%c0, %arg0, %c0] (...) :
-      //     memref<1x43x2xi32>, vector<1x2xi32>
-      // which would be collapsed to:
-      //   %1 = vector.transfer_read %collapse_shape[%c0, %offset] (...) :
-      //     memref<1x86xi32>, vector<2xi32>
-      // one would get the following offset:
-      //   %offset = %arg0 * 43
-      OpFoldResult collapsedOffset =
-          rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult();
-
-      auto sourceShape = sourceType.getShape();
-      auto collapsedStrides = computeSuffixProduct(ArrayRef<int64_t>(
-          sourceShape.begin() + firstDimToCollapse, sourceShape.end()));
-
-      // Compute the collapsed offset.
-      ArrayRef<Value> indicesToCollapse(indices.begin() + firstDimToCollapse,
-                                        indices.end());
-      auto &&[collapsedExpr, collapsedVals] = computeLinearIndex(
-          collapsedOffset, collapsedStrides, indicesToCollapse);
-      collapsedOffset = affine::makeComposedFoldedAffineApply(
-          rewriter, loc, collapsedExpr, collapsedVals);
-
-      if (collapsedOffset.is<Value>()) {
-        collapsedIndices.push_back(collapsedOffset.get<Value>());
-      } else {
-        collapsedIndices.push_back(rewriter.create<arith::ConstantIndexOp>(
-            loc, *getConstantIntValue(collapsedOffset)));
-      }
-    }
+    SmallVector<Value> collapsedIndices =
+        getCollapsedIndices(rewriter, loc, sourceType.getShape(),
+                            transferReadOp.getIndices(), firstDimToCollapse);
 
     // 3. Create new vector.transfer_read that reads from the collapsed memref
     VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
@@ -697,8 +688,7 @@ public:
       return failure();
     if (!vector::isContiguousSlice(sourceType, vectorType))
       return failure();
-    int64_t firstContiguousInnerDim =
-        sourceType.getRank() - vectorType.getRank();
+    int64_t firstDimToCollapse = sourceType.getRank() - vectorType.getRank();
     // TODO: generalize this pattern, relax the requirements here.
     if (transferWriteOp.hasOutOfBoundsDim())
       return failure();
@@ -706,22 +696,23 @@ public:
       return failure();
     if (transferWriteOp.getMask())
      return failure();
-    SmallVector<Value> collapsedIndices;
-    if (failed(checkAndCollapseInnerZeroIndices(transferWriteOp.getIndices(),
-                                                firstContiguousInnerDim,
-                                                collapsedIndices)))
-      return failure();
+
+    SmallVector<Value> collapsedIndices =
+        getCollapsedIndices(rewriter, loc, sourceType.getShape(),
+                            transferWriteOp.getIndices(), firstDimToCollapse);
 
     Value collapsedSource =
-        collapseInnerDims(rewriter, loc, source, firstContiguousInnerDim);
+        collapseInnerDims(rewriter, loc, source, firstDimToCollapse);
     MemRefType collapsedSourceType =
         cast<MemRefType>(collapsedSource.getType());
     int64_t collapsedRank = collapsedSourceType.getRank();
-    assert(collapsedRank == firstContiguousInnerDim + 1);
+    assert(collapsedRank == firstDimToCollapse + 1);
 
     SmallVector<AffineExpr, 1> dimExprs{
-        getAffineDimExpr(firstContiguousInnerDim, rewriter.getContext())};
+        getAffineDimExpr(firstDimToCollapse, rewriter.getContext())};
     auto collapsedMap =
         AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());
 
     VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
                                                 vectorType.getElementType());
     Value flatVector =
@@ -471,16 +471,16 @@ func.func @regression_non_contiguous_dim_read(%subview : memref<1x3x3x2xf32, str
 }
 
 // CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 2)>
 // CHECK-LABEL: func.func @regression_non_contiguous_dim_read(
 // CHECK: %[[COLLAPSE:.+]] = memref.collapse_shape %{{.*}} {{\[}}[0], [1], [2, 3]] : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>> into memref<1x3x6xf32, strided<[40, 10, 1], offset: ?>>
 // CHECK: %[[APPLY:.*]] = affine.apply #[[$MAP]]()
 
 // CHECK-128B-LABEL: func @regression_non_contiguous_dim_read(
 // CHECK-128B: memref.collapse_shape
 
 // -----
 
-func.func @unsupported_non_contiguous_dim_write(%value : vector<2x2xf32>,
+func.func @regression_non_contiguous_dim_write(%value : vector<2x2xf32>,
   %subview : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>,
   %idx0 : index, %idx1 : index) {
   %c0 = arith.constant 0 : index
@@ -488,8 +488,35 @@ func.func @unsupported_non_contiguous_dim_write(%value : vector<2x2xf32>,
   return
 }
 
-// CHECK-LABEL: func.func @unsupported_non_contiguous_dim_write(
-// CHECK-NOT: memref.collapse_shape
+// CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 2)>
+// CHECK-LABEL: func.func @regression_non_contiguous_dim_write(
+// CHECK: %[[APPLY:.*]] = affine.apply #[[$MAP]]()
+// CHECK: %[[COLLAPSE:.+]] = memref.collapse_shape %{{.*}} {{\[}}[0], [1], [2, 3]] : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>> into memref<1x3x6xf32, strided<[40, 10, 1], offset: ?>>
 
-// CHECK-128B-LABEL: func @unsupported_non_contiguous_dim_write(
-// CHECK-128B-NOT: memref.collapse_shape
+// CHECK-128B-LABEL: func @regression_non_contiguous_dim_write(
+// CHECK-128B: memref.collapse_shape
+
+// -----
+
+func.func @negative_out_of_bound_transfer_read(
+    %arg : memref<?x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<5x4x3x2xi8> {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0 : i8
+  %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst {in_bounds = [false, true, true, true]} :
+    memref<?x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<5x4x3x2xi8>
+  return %v : vector<5x4x3x2xi8>
+}
+// CHECK: func.func @negative_out_of_bound_transfer_read
+// CHECK-NOT: memref.collapse_shape
+
+// -----
+
+func.func @negative_out_of_bound_transfer_write(
+    %arg : memref<?x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, %vec : vector<1x1x3x2xi8>) {
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] {in_bounds = [false, true, true, true]} :
+    vector<1x1x3x2xi8>, memref<?x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>
+  return
+}
+// CHECK: func.func @negative_out_of_bound_transfer_write
+// CHECK-NOT: memref.collapse_shape