mirror of
https://github.com/intel/llvm.git
synced 2026-02-01 08:56:15 +08:00
Enable producer-consumer fusion for liveout memrefs if consumer read region matches producer write region.
-- PiperOrigin-RevId: 241517207
This commit is contained in:
@@ -1244,11 +1244,10 @@ static uint64_t getSliceIterationCount(
|
||||
|
||||
// Checks if node 'srcId' (which writes to a live out memref), can be safely
|
||||
// fused into node 'dstId'. Returns true if the following conditions are met:
|
||||
// *) 'srcNode' writes only writes to live out 'memref'.
|
||||
// *) 'srcNode' only writes to live out 'memref'.
|
||||
// *) 'srcNode' has exaclty one output edge on 'memref' (which is to 'dstId').
|
||||
// *) 'dstNode' does write to 'memref'.
|
||||
// *) 'dstNode's write region to 'memref' is a super set of 'srcNode's write
|
||||
// region to 'memref'.
|
||||
// *) 'dstNode's read/write region to 'memref' is a super set of 'srcNode's
|
||||
// write region to 'memref'.
|
||||
// TODO(andydavis) Generalize this to handle more live in/out cases.
|
||||
static bool canFuseSrcWhichWritesToLiveOut(unsigned srcId, unsigned dstId,
|
||||
Value *memref,
|
||||
@@ -1256,13 +1255,17 @@ static bool canFuseSrcWhichWritesToLiveOut(unsigned srcId, unsigned dstId,
|
||||
auto *srcNode = mdg->getNode(srcId);
|
||||
auto *dstNode = mdg->getNode(dstId);
|
||||
|
||||
// Gather all memrefs from 'srcNode' store ops.
|
||||
DenseSet<Value *> storeMemrefs;
|
||||
for (auto *storeOpInst : srcNode->stores) {
|
||||
storeMemrefs.insert(storeOpInst->cast<StoreOp>().getMemRef());
|
||||
}
|
||||
// Return false if any of the following are true:
|
||||
// *) 'srcNode' writes to a live in/out memref other than 'memref'.
|
||||
// *) 'srcNode' has more than one output edge on 'memref'.
|
||||
// *) 'dstNode' does not write to 'memref'.
|
||||
if (srcNode->getStoreOpCount(memref) != 1 ||
|
||||
mdg->getOutEdgeCount(srcNode->id, memref) != 1 ||
|
||||
dstNode->getStoreOpCount(memref) == 0)
|
||||
// Check that all stores are to the same memref.
|
||||
if (storeMemrefs.size() != 1 ||
|
||||
mdg->getOutEdgeCount(srcNode->id, memref) != 1)
|
||||
return false;
|
||||
// Compute MemRefRegion 'srcWriteRegion' for 'srcStoreOpInst' on 'memref'.
|
||||
auto *srcStoreOpInst = srcNode->stores.front();
|
||||
@@ -1280,23 +1283,26 @@ static bool canFuseSrcWhichWritesToLiveOut(unsigned srcId, unsigned dstId,
|
||||
if (!srcNumElements.hasValue())
|
||||
return false;
|
||||
|
||||
// Compute MemRefRegion 'dstWriteRegion' for 'dstStoreOpInst' on 'memref'.
|
||||
SmallVector<Operation *, 2> dstStoreOps;
|
||||
dstNode->getStoreOpsForMemref(memref, &dstStoreOps);
|
||||
// Compute MemRefRegion 'dstRegion' for 'dstStore/LoadOpInst' on 'memref'.
|
||||
// TODO(andydavis) Compute 'unionboundingbox' of all write regions (one for
|
||||
// each store op in 'dstStoreOps').
|
||||
auto *dstStoreOpInst = dstStoreOps[0];
|
||||
MemRefRegion dstWriteRegion(dstStoreOpInst->getLoc());
|
||||
if (failed(dstWriteRegion.compute(dstStoreOpInst, /*loopDepth=*/0))) {
|
||||
SmallVector<Operation *, 2> dstStoreOps;
|
||||
dstNode->getStoreOpsForMemref(memref, &dstStoreOps);
|
||||
SmallVector<Operation *, 2> dstLoadOps;
|
||||
dstNode->getLoadOpsForMemref(memref, &dstLoadOps);
|
||||
|
||||
auto *dstOpInst = dstStoreOps.empty() ? dstLoadOps[0] : dstStoreOps[0];
|
||||
MemRefRegion dstRegion(dstOpInst->getLoc());
|
||||
if (failed(dstRegion.compute(dstOpInst, /*loopDepth=*/0))) {
|
||||
LLVM_DEBUG(llvm::dbgs()
|
||||
<< "Unable to compute MemRefRegion for dest operation\n.");
|
||||
return false;
|
||||
}
|
||||
SmallVector<int64_t, 4> dstShape;
|
||||
// Query 'dstWriteRegion' for 'dstShape' and 'dstNumElements'.
|
||||
// by 'dstStoreOpInst' at depth 'dstLoopDepth'.
|
||||
// Query 'dstRegion' for 'dstShape' and 'dstNumElements'.
|
||||
// by 'dstOpInst' at depth 'dstLoopDepth'.
|
||||
Optional<int64_t> dstNumElements =
|
||||
dstWriteRegion.getConstantBoundingSizeAndShape(&dstShape);
|
||||
dstRegion.getConstantBoundingSizeAndShape(&dstShape);
|
||||
if (!dstNumElements.hasValue())
|
||||
return false;
|
||||
|
||||
|
||||
@@ -1261,15 +1261,18 @@ func @should_not_fuse_live_out_arg(%arg0: memref<10xf32>) {
|
||||
affine.for %i0 = 0 to 10 {
|
||||
store %cf7, %arg0[%i0] : memref<10xf32>
|
||||
}
|
||||
affine.for %i1 = 0 to 10 {
|
||||
affine.for %i1 = 0 to 9 {
|
||||
%v0 = load %arg0[%i1] : memref<10xf32>
|
||||
}
|
||||
// This tests that the loop nest '%i0' should not be removed after fusion
|
||||
// because it writes to memref argument '%arg0'.
|
||||
// because it writes to memref argument '%arg0', and its read region
|
||||
// does not cover its write region (so fusion would shrink the write region
|
||||
// in the fused loop nest, so complete live out data region would not
|
||||
// be written).
|
||||
// CHECK: affine.for %i0 = 0 to 10 {
|
||||
// CHECK-NEXT: store %cst, %arg0[%i0] : memref<10xf32>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: affine.for %i1 = 0 to 10 {
|
||||
// CHECK-NEXT: affine.for %i1 = 0 to 9 {
|
||||
// CHECK-NEXT: %0 = load %arg0[%i1] : memref<10xf32>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: return
|
||||
@@ -1278,6 +1281,29 @@ func @should_not_fuse_live_out_arg(%arg0: memref<10xf32>) {
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @should_fuse_live_out_arg(%arg0: memref<10xf32>) {
|
||||
func @should_fuse_live_out_arg(%arg0: memref<10xf32>) {
|
||||
%cf7 = constant 7.0 : f32
|
||||
|
||||
affine.for %i0 = 0 to 10 {
|
||||
store %cf7, %arg0[%i0] : memref<10xf32>
|
||||
}
|
||||
affine.for %i1 = 0 to 10 {
|
||||
%v0 = load %arg0[%i1] : memref<10xf32>
|
||||
}
|
||||
// The read/write regions for memref '%arg0' are the same for both
|
||||
// loops, so they should fuse.
|
||||
|
||||
// CHECK: affine.for %i0 = 0 to 10 {
|
||||
// CHECK-NEXT: store %cst, %arg0[%i0] : memref<10xf32>
|
||||
// CHECK-NEXT: %0 = load %arg0[%i0] : memref<10xf32>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: return
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @should_not_fuse_escaping_memref() -> memref<10xf32>
|
||||
func @should_not_fuse_escaping_memref() -> memref<10xf32> {
|
||||
%cf7 = constant 7.0 : f32
|
||||
@@ -1285,7 +1311,7 @@ func @should_not_fuse_escaping_memref() -> memref<10xf32> {
|
||||
affine.for %i0 = 0 to 10 {
|
||||
store %cf7, %m[%i0] : memref<10xf32>
|
||||
}
|
||||
affine.for %i1 = 0 to 10 {
|
||||
affine.for %i1 = 0 to 9 {
|
||||
%v0 = load %m[%i1] : memref<10xf32>
|
||||
}
|
||||
// This tests that the loop nest '%i0' should not be removed after fusion
|
||||
@@ -1294,7 +1320,7 @@ func @should_not_fuse_escaping_memref() -> memref<10xf32> {
|
||||
// CHECK: affine.for %i0 = 0 to 10 {
|
||||
// CHECK-NEXT: store %cst, %0[%i0] : memref<10xf32>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: affine.for %i1 = 0 to 10 {
|
||||
// CHECK-NEXT: affine.for %i1 = 0 to 9 {
|
||||
// CHECK-NEXT: %1 = load %0[%i1] : memref<10xf32>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: return %0 : memref<10xf32>
|
||||
@@ -2410,3 +2436,56 @@ func @affine_2mm_fused(%arg0: memref<1024x1024xf32>, %arg1: memref<1024x1024xf32
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @affine_2_dependent_mm_fused(%arg0: memref<1024x1024xf32>, %arg1: memref<1024x1024xf32>, %arg2: memref<1024x1024xf32>, %arg3: memref<1024x1024xf32>, %arg4: memref<1024x1024xf32>) {
|
||||
affine.for %i0 = 0 to 1024 {
|
||||
affine.for %i1 = 0 to 1024 {
|
||||
affine.for %i2 = 0 to 1024 {
|
||||
%0 = load %arg1[%i2, %i1] : memref<1024x1024xf32>
|
||||
%1 = load %arg0[%i0, %i2] : memref<1024x1024xf32>
|
||||
%2 = mulf %1, %0 : f32
|
||||
%3 = load %arg2[%i0, %i1] : memref<1024x1024xf32>
|
||||
%4 = addf %3, %2 : f32
|
||||
store %4, %arg2[%i0, %i1] : memref<1024x1024xf32>
|
||||
}
|
||||
}
|
||||
}
|
||||
affine.for %i3 = 0 to 1024 {
|
||||
affine.for %i4 = 0 to 1024 {
|
||||
affine.for %i5 = 0 to 1024 {
|
||||
%5 = load %arg3[%i5, %i4] : memref<1024x1024xf32>
|
||||
%6 = load %arg2[%i3, %i5] : memref<1024x1024xf32>
|
||||
%7 = mulf %6, %5 : f32
|
||||
%8 = load %arg4[%i3, %i4] : memref<1024x1024xf32>
|
||||
%9 = addf %8, %7 : f32
|
||||
store %9, %arg4[%i3, %i4] : memref<1024x1024xf32>
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// CHECK: affine.for %i0 = 0 to 1024 {
|
||||
// CHECK-NEXT: affine.for %i1 = 0 to 1024 {
|
||||
// CHECK-NEXT: affine.for %i2 = 0 to 1024 {
|
||||
// CHECK-NEXT: %0 = load %arg1[%i2, %i1] : memref<1024x1024xf32>
|
||||
// CHECK-NEXT: %1 = load %arg0[%i0, %i2] : memref<1024x1024xf32>
|
||||
// CHECK-NEXT: %2 = mulf %1, %0 : f32
|
||||
// CHECK-NEXT: %3 = load %arg2[%i0, %i1] : memref<1024x1024xf32>
|
||||
// CHECK-NEXT: %4 = addf %3, %2 : f32
|
||||
// CHECK-NEXT: store %4, %arg2[%i0, %i1] : memref<1024x1024xf32>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: affine.for %i3 = 0 to 1024 {
|
||||
// CHECK-NEXT: affine.for %i4 = 0 to 1024 {
|
||||
// CHECK-NEXT: %5 = load %arg3[%i4, %i3] : memref<1024x1024xf32>
|
||||
// CHECK-NEXT: %6 = load %arg2[%i0, %i4] : memref<1024x1024xf32>
|
||||
// CHECK-NEXT: %7 = mulf %6, %5 : f32
|
||||
// CHECK-NEXT: %8 = load %arg4[%i0, %i3] : memref<1024x1024xf32>
|
||||
// CHECK-NEXT: %9 = addf %8, %7 : f32
|
||||
// CHECK-NEXT: store %9, %arg4[%i0, %i3] : memref<1024x1024xf32>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }
|
||||
return
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user