mirror of
https://github.com/intel/llvm.git
synced 2026-01-25 19:44:38 +08:00
[mlir][vector] Teach TransferOptimization to look through trivial aliases (#87805)
This allows `TransferOptimization` to eliminate and forward stores that are to trivial aliases (rather than just to identical memref values). A trivial aliases is (currently) defined as: 1. A `memref.cast` 2. A `memref.subview` with a zero offset and unit strides 3. A chain of 1 and 2
This commit is contained in:
@@ -22,6 +22,9 @@ namespace mlir {
|
||||
|
||||
class MemRefType;
|
||||
|
||||
/// A value with a memref type.
|
||||
using MemrefValue = TypedValue<BaseMemRefType>;
|
||||
|
||||
namespace memref {
|
||||
|
||||
/// Returns true, if the memref type has static shapes and represents a
|
||||
@@ -93,6 +96,20 @@ computeStridesIRBlock(Location loc, OpBuilder &builder,
|
||||
return computeSuffixProductIRBlock(loc, builder, sizes);
|
||||
}
|
||||
|
||||
/// Walk up the source chain until an operation that changes/defines the view of
|
||||
/// memory is found (i.e. skip operations that alias the entire view).
|
||||
MemrefValue skipFullyAliasingOperations(MemrefValue source);
|
||||
|
||||
/// Checks if two (memref) values are the same or are statically known to alias
|
||||
/// the same region of memory.
|
||||
inline bool isSameViewOrTrivialAlias(MemrefValue a, MemrefValue b) {
|
||||
return skipFullyAliasingOperations(a) == skipFullyAliasingOperations(b);
|
||||
}
|
||||
|
||||
/// Walk up the source chain until something an op other than a `memref.subview`
|
||||
/// or `memref.cast` is found.
|
||||
MemrefValue skipSubViewsAndCasts(MemrefValue source);
|
||||
|
||||
} // namespace memref
|
||||
} // namespace mlir
|
||||
|
||||
|
||||
@@ -178,5 +178,35 @@ computeSuffixProductIRBlock(Location loc, OpBuilder &builder,
|
||||
return computeSuffixProductIRBlockImpl(loc, builder, sizes, unit);
|
||||
}
|
||||
|
||||
MemrefValue skipFullyAliasingOperations(MemrefValue source) {
|
||||
while (auto op = source.getDefiningOp()) {
|
||||
if (auto subViewOp = dyn_cast<memref::SubViewOp>(op);
|
||||
subViewOp && subViewOp.hasZeroOffset() && subViewOp.hasUnitStride()) {
|
||||
// A `memref.subview` with an all zero offset, and all unit strides, still
|
||||
// points to the same memory.
|
||||
source = cast<MemrefValue>(subViewOp.getSource());
|
||||
} else if (auto castOp = dyn_cast<memref::CastOp>(op)) {
|
||||
// A `memref.cast` still points to the same memory.
|
||||
source = castOp.getSource();
|
||||
} else {
|
||||
return source;
|
||||
}
|
||||
}
|
||||
return source;
|
||||
}
|
||||
|
||||
MemrefValue skipSubViewsAndCasts(MemrefValue source) {
|
||||
while (auto op = source.getDefiningOp()) {
|
||||
if (auto subView = dyn_cast<memref::SubViewOp>(op)) {
|
||||
source = cast<MemrefValue>(subView.getSource());
|
||||
} else if (auto cast = dyn_cast<memref::CastOp>(op)) {
|
||||
source = cast.getSource();
|
||||
} else {
|
||||
return source;
|
||||
}
|
||||
}
|
||||
return source;
|
||||
}
|
||||
|
||||
} // namespace memref
|
||||
} // namespace mlir
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Dialect/Utils/IndexingUtils.h"
|
||||
#include "mlir/Dialect/Vector/IR/VectorOps.h"
|
||||
@@ -104,10 +105,8 @@ void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
|
||||
<< "\n");
|
||||
llvm::SmallVector<Operation *, 8> blockingAccesses;
|
||||
Operation *firstOverwriteCandidate = nullptr;
|
||||
Value source = write.getSource();
|
||||
// Skip subview ops.
|
||||
while (auto subView = source.getDefiningOp<memref::SubViewOp>())
|
||||
source = subView.getSource();
|
||||
Value source =
|
||||
memref::skipSubViewsAndCasts(cast<MemrefValue>(write.getSource()));
|
||||
llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
|
||||
source.getUsers().end());
|
||||
llvm::SmallDenseSet<Operation *, 32> processed;
|
||||
@@ -116,8 +115,8 @@ void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
|
||||
// If the user has already been processed skip.
|
||||
if (!processed.insert(user).second)
|
||||
continue;
|
||||
if (auto subView = dyn_cast<memref::SubViewOp>(user)) {
|
||||
users.append(subView->getUsers().begin(), subView->getUsers().end());
|
||||
if (isa<memref::SubViewOp, memref::CastOp>(user)) {
|
||||
users.append(user->getUsers().begin(), user->getUsers().end());
|
||||
continue;
|
||||
}
|
||||
if (isMemoryEffectFree(user))
|
||||
@@ -126,7 +125,9 @@ void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
|
||||
continue;
|
||||
if (auto nextWrite = dyn_cast<vector::TransferWriteOp>(user)) {
|
||||
// Check candidate that can override the store.
|
||||
if (write.getSource() == nextWrite.getSource() &&
|
||||
if (memref::isSameViewOrTrivialAlias(
|
||||
cast<MemrefValue>(nextWrite.getSource()),
|
||||
cast<MemrefValue>(write.getSource())) &&
|
||||
checkSameValueWAW(nextWrite, write) &&
|
||||
postDominators.postDominates(nextWrite, write)) {
|
||||
if (firstOverwriteCandidate == nullptr ||
|
||||
@@ -191,10 +192,8 @@ void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
|
||||
<< "\n");
|
||||
SmallVector<Operation *, 8> blockingWrites;
|
||||
vector::TransferWriteOp lastwrite = nullptr;
|
||||
Value source = read.getSource();
|
||||
// Skip subview ops.
|
||||
while (auto subView = source.getDefiningOp<memref::SubViewOp>())
|
||||
source = subView.getSource();
|
||||
Value source =
|
||||
memref::skipSubViewsAndCasts(cast<MemrefValue>(read.getSource()));
|
||||
llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
|
||||
source.getUsers().end());
|
||||
llvm::SmallDenseSet<Operation *, 32> processed;
|
||||
@@ -203,12 +202,8 @@ void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
|
||||
// If the user has already been processed skip.
|
||||
if (!processed.insert(user).second)
|
||||
continue;
|
||||
if (auto subView = dyn_cast<memref::SubViewOp>(user)) {
|
||||
users.append(subView->getUsers().begin(), subView->getUsers().end());
|
||||
continue;
|
||||
}
|
||||
if (auto collapsed = dyn_cast<memref::CollapseShapeOp>(user)) {
|
||||
users.append(collapsed->getUsers().begin(), collapsed->getUsers().end());
|
||||
if (isa<memref::SubViewOp, memref::CollapseShapeOp, memref::CastOp>(user)) {
|
||||
users.append(user->getUsers().begin(), user->getUsers().end());
|
||||
continue;
|
||||
}
|
||||
if (isMemoryEffectFree(user) || isa<vector::TransferReadOp>(user))
|
||||
@@ -221,7 +216,9 @@ void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
|
||||
cast<VectorTransferOpInterface>(read.getOperation()),
|
||||
/*testDynamicValueUsingBounds=*/true))
|
||||
continue;
|
||||
if (write.getSource() == read.getSource() &&
|
||||
if (memref::isSameViewOrTrivialAlias(
|
||||
cast<MemrefValue>(read.getSource()),
|
||||
cast<MemrefValue>(write.getSource())) &&
|
||||
dominators.dominates(write, read) && checkSameValueRAW(write, read)) {
|
||||
if (lastwrite == nullptr || dominators.dominates(lastwrite, write))
|
||||
lastwrite = write;
|
||||
|
||||
@@ -485,3 +485,33 @@ func.func @forward_dead_constant_splat_store_with_masking_negative_3(%buffer : m
|
||||
vector.transfer_write %x, %buffer[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// Here each read/write is to a different subview, but they all point to exact
|
||||
// same bit of memory (just through casts and subviews with unit strides and
|
||||
// zero offsets).
|
||||
// CHECK-LABEL: func @forward_and_eliminate_stores_through_trivial_aliases
|
||||
// CHECK-NOT: vector.transfer_write
|
||||
// CHECK-NOT: vector.transfer_read
|
||||
// CHECK: scf.for
|
||||
// CHECK: }
|
||||
// CHECK: vector.transfer_write
|
||||
// CHECK: return
|
||||
func.func @forward_and_eliminate_stores_through_trivial_aliases(
|
||||
%buffer : memref<?x?xf32>, %vec: vector<[8]x[8]xf32>, %size: index, %a_size: index, %another_size: index
|
||||
) {
|
||||
%c0 = arith.constant 0 : index
|
||||
%c1 = arith.constant 1 : index
|
||||
%c32 = arith.constant 32 : index
|
||||
%cst = arith.constant 0.0 : f32
|
||||
vector.transfer_write %vec, %buffer[%c0, %c0] {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
|
||||
%direct_subview = memref.subview %buffer[0, 0] [%a_size, %a_size] [1, 1] : memref<?x?xf32> to memref<?x?xf32, strided<[?, 1], offset: ?>>
|
||||
%cast = memref.cast %direct_subview : memref<?x?xf32, strided<[?, 1], offset: ?>> to memref<?x?xf32>
|
||||
%subview_of_cast = memref.subview %cast[0, 0] [%another_size, %another_size] [1, 1] : memref<?x?xf32> to memref<?x?xf32, strided<[?, 1], offset: ?>>
|
||||
%21 = vector.transfer_read %direct_subview[%c0, %c0], %cst {in_bounds = [true, true]} : memref<?x?xf32, strided<[?, 1], offset: ?>>, vector<[8]x[8]xf32>
|
||||
%23 = scf.for %arg2 = %c0 to %c32 step %c1 iter_args(%arg3 = %21) -> (vector<[8]x[8]xf32>) {
|
||||
%24 = arith.addf %arg3, %arg3 : vector<[8]x[8]xf32>
|
||||
scf.yield %24 : vector<[8]x[8]xf32>
|
||||
}
|
||||
vector.transfer_write %23, %subview_of_cast[%c0, %c0] {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32, strided<[?, 1], offset: ?>>
|
||||
return
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user