From a78edcda5bb5ba6d89d2efd3004becb7e3a9fc95 Mon Sep 17 00:00:00 2001
From: MLIR Team
Date: Tue, 5 Feb 2019 06:57:08 -0800
Subject: [PATCH] Loop fusion improvements:

*) After a private memref buffer is created for a fused loop nest, dependences
   on the old memref are reduced, which can open up fusion opportunities. In
   these cases, users of the old memref are added back to the worklist to be
   reconsidered for fusion.
*) Fixed a bug in the fusion insertion point dependence check where the memref
   being privatized was being skipped by the check.

PiperOrigin-RevId: 232477853
---
 mlir/lib/Transforms/LoopFusion.cpp    | 35 ++++++++++---
 mlir/test/Transforms/loop-fusion.mlir | 73 +++++++++++++++++++++------
 2 files changed, 86 insertions(+), 22 deletions(-)

diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp
index 77e5a6aa04f8..d7e1b610022b 100644
--- a/mlir/lib/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Transforms/LoopFusion.cpp
@@ -365,21 +365,20 @@ public:
   // Computes and returns an insertion point instruction, before which the
   // fused loop nest can be inserted while preserving
   // dependences. Returns nullptr if no such insertion point is found.
-  Instruction *getFusedLoopNestInsertionPoint(unsigned srcId, unsigned dstId,
-                                              Value *memrefToSkip) {
+  Instruction *getFusedLoopNestInsertionPoint(unsigned srcId, unsigned dstId) {
     if (outEdges.count(srcId) == 0)
       return getNode(dstId)->inst;
 
     // Build set of insts in range (srcId, dstId) which depend on 'srcId'.
     SmallPtrSet<Instruction *, 2> srcDepInsts;
     for (auto &outEdge : outEdges[srcId])
-      if (outEdge.id != dstId && outEdge.value != memrefToSkip)
+      if (outEdge.id != dstId)
         srcDepInsts.insert(getNode(outEdge.id)->inst);
 
     // Build set of insts in range (srcId, dstId) on which 'dstId' depends.
     SmallPtrSet<Instruction *, 2> dstDepInsts;
     for (auto &inEdge : inEdges[dstId])
-      if (inEdge.id != srcId && inEdge.value != memrefToSkip)
+      if (inEdge.id != srcId)
         dstDepInsts.insert(getNode(inEdge.id)->inst);
 
     Instruction *srcNodeInst = getNode(srcId)->inst;
@@ -1366,18 +1365,24 @@ static bool isFusionProfitable(Instruction *srcOpInst,
 struct GreedyFusion {
 public:
   MemRefDependenceGraph *mdg;
-  SmallVector<unsigned, 4> worklist;
+  SmallVector<unsigned, 8> worklist;
+  llvm::SmallDenseSet<unsigned, 16> worklistSet;
 
   GreedyFusion(MemRefDependenceGraph *mdg) : mdg(mdg) {
     // Initialize worklist with nodes from 'mdg'.
+    // TODO(andydavis) Add a priority queue for prioritizing nodes by different
+    // metrics (e.g. arithmetic intensity/flops-to-bytes ratio).
     worklist.resize(mdg->nodes.size());
     std::iota(worklist.begin(), worklist.end(), 0);
+    worklistSet.insert(worklist.begin(), worklist.end());
   }
 
   void run(unsigned localBufSizeThreshold, Optional<unsigned> fastMemorySpace) {
     while (!worklist.empty()) {
       unsigned dstId = worklist.back();
       worklist.pop_back();
+      worklistSet.erase(dstId);
+
       // Skip if this node was removed (fused into another node).
       if (mdg->nodes.count(dstId) == 0)
         continue;
@@ -1437,8 +1442,8 @@ public:
 
         // Compute an instruction list insertion point for the fused loop
         // nest which preserves dependences.
-        Instruction *insertPointInst = mdg->getFusedLoopNestInsertionPoint(
-            srcNode->id, dstNode->id, memref);
+        Instruction *insertPointInst =
+            mdg->getFusedLoopNestInsertionPoint(srcNode->id, dstNode->id);
         if (insertPointInst == nullptr)
           continue;
 
@@ -1516,6 +1521,22 @@ public:
         if (mdg->canRemoveNode(srcNode->id)) {
           mdg->removeNode(srcNode->id);
           srcNode->inst->erase();
+        } else {
+          // Add remaining users of 'oldMemRef' back on the worklist (if not
+          // already there), as its replacement with a local/private memref
+          // has reduced dependences on 'oldMemRef' which may have created
+          // new fusion opportunities.
+          if (mdg->outEdges.count(srcNode->id) > 0) {
+            SmallVector<MemRefDependenceGraph::Edge, 2> oldOutEdges =
+                mdg->outEdges[srcNode->id];
+            for (auto &outEdge : oldOutEdges) {
+              if (outEdge.value == memref &&
+                  worklistSet.count(outEdge.id) == 0) {
+                worklist.push_back(outEdge.id);
+                worklistSet.insert(outEdge.id);
+              }
+            }
+          }
         }
       }
     }
diff --git a/mlir/test/Transforms/loop-fusion.mlir b/mlir/test/Transforms/loop-fusion.mlir
index bb8dd0db73e3..57c1ec4ceed9 100644
--- a/mlir/test/Transforms/loop-fusion.mlir
+++ b/mlir/test/Transforms/loop-fusion.mlir
@@ -340,10 +340,9 @@ func @should_not_fuse_would_create_cycle() {
 }
 
 // -----
 
-// CHECK: [[MAP0:#map[0-9]+]] = (d0, d1) -> (-d0 + d1)
-// CHECK-LABEL: func @should_fuse_across_waw_dep_with_private_memref() {
-func @should_fuse_across_waw_dep_with_private_memref() {
+// CHECK-LABEL: func @should_not_fuse_across_waw_dep() {
+func @should_not_fuse_across_waw_dep() {
   %m = alloc() : memref<10xf32>
 
   %cf7 = constant 7.0 : f32
@@ -358,16 +357,13 @@ func @should_fuse_across_waw_dep_with_private_memref() {
   }
   // Fusing loop %i0 to %i2 would violate the WAW dependence between %i0 and %i1
   // CHECK: for %i0 = 0 to 10 {
-  // CHECK-NEXT: store %cst, %1[%i0] : memref<10xf32>
+  // CHECK-NEXT: store %cst, %0[%i0] : memref<10xf32>
   // CHECK-NEXT: }
   // CHECK: for %i1 = 0 to 10 {
-  // CHECK-NEXT: store %cst, %1[%i1] : memref<10xf32>
+  // CHECK-NEXT: store %cst, %0[%i1] : memref<10xf32>
   // CHECK-NEXT: }
   // CHECK: for %i2 = 0 to 10 {
-  // CHECK-NEXT: %2 = affine_apply [[MAP0]](%i2, %i2)
-  // CHECK-NEXT: store %cst, %0[%2] : memref<1xf32>
-  // CHECK-NEXT: %3 = affine_apply [[MAP0]](%i2, %i2)
-  // CHECK-NEXT: %4 = load %0[%3] : memref<1xf32>
+  // CHECK-NEXT: %1 = load %0[%i2] : memref<10xf32>
   // CHECK-NEXT: }
   // CHECK-NEXT: return
   return
@@ -1234,17 +1230,17 @@ func @should_fuse_with_private_memrefs_with_diff_shapes() {
   // by loops %i1 and %i2.
   // CHECK-DAG: %0 = alloc() : memref<1xf32>
   // CHECK-DAG: %1 = alloc() : memref<1xf32>
-  // CHECK: for %i0 = 0 to 17 {
+  // CHECK: for %i0 = 0 to 82 {
   // CHECK-NEXT: %2 = affine_apply [[MAP0]](%i0, %i0)
-  // CHECK-NEXT: store %cst, %0[%2] : memref<1xf32>
+  // CHECK-NEXT: store %cst, %1[%2] : memref<1xf32>
   // CHECK-NEXT: %3 = affine_apply [[MAP0]](%i0, %i0)
-  // CHECK-NEXT: %4 = load %0[%3] : memref<1xf32>
+  // CHECK-NEXT: %4 = load %1[%3] : memref<1xf32>
   // CHECK-NEXT: }
-  // CHECK-NEXT: for %i1 = 0 to 82 {
+  // CHECK-NEXT: for %i1 = 0 to 17 {
   // CHECK-NEXT: %5 = affine_apply [[MAP0]](%i1, %i1)
-  // CHECK-NEXT: store %cst, %1[%5] : memref<1xf32>
+  // CHECK-NEXT: store %cst, %0[%5] : memref<1xf32>
   // CHECK-NEXT: %6 = affine_apply [[MAP0]](%i1, %i1)
-  // CHECK-NEXT: %7 = load %1[%6] : memref<1xf32>
+  // CHECK-NEXT: %7 = load %0[%6] : memref<1xf32>
   // CHECK-NEXT: }
   // CHECK-NEXT: return
   return
@@ -1652,3 +1648,50 @@ func @should_fuse_at_depth_above_loop_carried_dependence(%arg0: memref<64x4xf32>
   // CHECK-NEXT: return
   return
 }
+
+// -----
+
+// CHECK: [[MAP0:#map[0-9]+]] = (d0, d1) -> (-d0 + d1)
+
+// CHECK-LABEL: func @should_fuse_after_private_memref_creation() {
+func @should_fuse_after_private_memref_creation() {
+  %a = alloc() : memref<10xf32>
+  %b = alloc() : memref<10xf32>
+
+  %cf7 = constant 7.0 : f32
+
+  for %i0 = 0 to 10 {
+    store %cf7, %a[%i0] : memref<10xf32>
+  }
+  for %i1 = 0 to 10 {
+    %v0 = load %a[%i1] : memref<10xf32>
+    store %v0, %b[%i1] : memref<10xf32>
+  }
+  for %i2 = 0 to 10 {
+    %v1 = load %a[%i2] : memref<10xf32>
+    store %v1, %b[%i2] : memref<10xf32>
+  }
+
+  // On the first visit to '%i2', the fusion algorithm can not fuse loop nest
+  // '%i0' into '%i2' because of the dependences '%i0' and '%i2' each have on
+  // '%i1'. However, once the loop nest '%i0' is fused into '%i1' with a
+  // private memref, the dependence between '%i0' and '%i1' on memref '%a' no
+  // longer exists, so '%i0' can now be fused into '%i2'.
+
+  // CHECK: for %i0 = 0 to 10 {
+  // CHECK-NEXT: %3 = affine_apply [[MAP0]](%i0, %i0)
+  // CHECK-NEXT: store %cst, %1[%3] : memref<1xf32>
+  // CHECK-NEXT: %4 = affine_apply [[MAP0]](%i0, %i0)
+  // CHECK-NEXT: %5 = load %1[%4] : memref<1xf32>
+  // CHECK-NEXT: store %5, %2[%i0] : memref<10xf32>
+  // CHECK-NEXT: }
+  // CHECK-NEXT: for %i1 = 0 to 10 {
+  // CHECK-NEXT: %6 = affine_apply [[MAP0]](%i1, %i1)
+  // CHECK-NEXT: store %cst, %0[%6] : memref<1xf32>
+  // CHECK-NEXT: %7 = affine_apply [[MAP0]](%i1, %i1)
+  // CHECK-NEXT: %8 = load %0[%7] : memref<1xf32>
+  // CHECK-NEXT: store %8, %2[%i1] : memref<10xf32>
+  // CHECK-NEXT: }
+  // CHECK-NEXT: return
+  return
+}
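
The heart of the first improvement is the small piece of worklist bookkeeping added to GreedyFusion::run(): when the source loop nest cannot be removed after its memref has been privatized, the remaining users of that memref are pushed back onto the worklist exactly once each, so the pass revisits them now that their dependences on the old memref have shrunk. What follows is a minimal, self-contained C++ sketch of that pattern, not the MLIR implementation; GraphModel, Edge, and the integer memref ids are hypothetical stand-ins for MemRefDependenceGraph, its Edge struct, and Value.

#include <cstdio>
#include <map>
#include <set>
#include <vector>

// Hypothetical stand-in for a dependence edge: a dependence on node 'id'
// carried on memref 'value' (an integer id in this toy model).
struct Edge {
  unsigned id;
  unsigned value;
};

// Hypothetical stand-in for the dependence graph: outgoing edges per node.
struct GraphModel {
  std::map<unsigned, std::vector<Edge>> outEdges;
};

int main() {
  GraphModel mdg;
  // Node 0 writes memref 7; node 2 still reads it after node 0 was fused
  // into node 1 and memref 7 was replaced by a private buffer there.
  mdg.outEdges[0] = {{2, 7}};

  // Worklist state mid-run: node 2 was already popped (its first fusion
  // attempt failed), node 1 is the destination being processed right now.
  std::vector<unsigned> worklist;
  std::set<unsigned> worklistSet;

  // Re-add the remaining users of the privatized memref, using the companion
  // set to avoid queuing any node twice.
  unsigned srcId = 0, fusedMemRef = 7;
  auto it = mdg.outEdges.find(srcId);
  if (it != mdg.outEdges.end()) {
    // Copy the edge list, as the patch does, so later graph mutations cannot
    // invalidate the iteration.
    std::vector<Edge> oldOutEdges = it->second;
    for (const Edge &outEdge : oldOutEdges) {
      if (outEdge.value == fusedMemRef && worklistSet.count(outEdge.id) == 0) {
        worklist.push_back(outEdge.id);
        worklistSet.insert(outEdge.id);
      }
    }
  }

  for (unsigned id : worklist)
    std::printf("re-added node %u for another fusion attempt\n", id);
  return 0;
}

Pairing the vector with a set mirrors the patch's worklist/worklistSet: the vector preserves cheap LIFO visitation order while the set makes the membership check constant time, which is exactly what the new else-branch in GreedyFusion::run() relies on.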