diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 77e5a6aa04f8..d7e1b610022b 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -365,21 +365,20 @@ public: // Computes and returns an insertion point instruction, before which the // the fused loop nest can be inserted while preserving // dependences. Returns nullptr if no such insertion point is found. - Instruction *getFusedLoopNestInsertionPoint(unsigned srcId, unsigned dstId, - Value *memrefToSkip) { + Instruction *getFusedLoopNestInsertionPoint(unsigned srcId, unsigned dstId) { if (outEdges.count(srcId) == 0) return getNode(dstId)->inst; // Build set of insts in range (srcId, dstId) which depend on 'srcId'. SmallPtrSet<Instruction *, 2> srcDepInsts; for (auto &outEdge : outEdges[srcId]) - if (outEdge.id != dstId && outEdge.value != memrefToSkip) + if (outEdge.id != dstId) srcDepInsts.insert(getNode(outEdge.id)->inst); // Build set of insts in range (srcId, dstId) on which 'dstId' depends. SmallPtrSet<Instruction *, 2> dstDepInsts; for (auto &inEdge : inEdges[dstId]) - if (inEdge.id != srcId && inEdge.value != memrefToSkip) + if (inEdge.id != srcId) dstDepInsts.insert(getNode(inEdge.id)->inst); Instruction *srcNodeInst = getNode(srcId)->inst; @@ -1366,18 +1365,24 @@ static bool isFusionProfitable(Instruction *srcOpInst, struct GreedyFusion { public: MemRefDependenceGraph *mdg; - SmallVector<unsigned, 4> worklist; + SmallVector<unsigned, 8> worklist; + llvm::SmallDenseSet<unsigned, 16> worklistSet; GreedyFusion(MemRefDependenceGraph *mdg) : mdg(mdg) { // Initialize worklist with nodes from 'mdg'. + // TODO(andydavis) Add a priority queue for prioritizing nodes by different + // metrics (e.g. arithmetic intensity/flops-to-bytes ratio). 
worklist.resize(mdg->nodes.size()); std::iota(worklist.begin(), worklist.end(), 0); + worklistSet.insert(worklist.begin(), worklist.end()); } void run(unsigned localBufSizeThreshold, Optional<unsigned> fastMemorySpace) { while (!worklist.empty()) { unsigned dstId = worklist.back(); worklist.pop_back(); + worklistSet.erase(dstId); + // Skip if this node was removed (fused into another node). if (mdg->nodes.count(dstId) == 0) continue; @@ -1437,8 +1442,8 @@ public: // Compute an instruction list insertion point for the fused loop // nest which preserves dependences. - Instruction *insertPointInst = mdg->getFusedLoopNestInsertionPoint( - srcNode->id, dstNode->id, memref); + Instruction *insertPointInst = + mdg->getFusedLoopNestInsertionPoint(srcNode->id, dstNode->id); if (insertPointInst == nullptr) continue; @@ -1516,6 +1521,22 @@ public: if (mdg->canRemoveNode(srcNode->id)) { mdg->removeNode(srcNode->id); srcNode->inst->erase(); + } else { + // Add remaining users of 'oldMemRef' back on the worklist (if not + // already there), as its replacement with a local/private memref + // has reduced dependences on 'oldMemRef' which may have created + // new fusion opportunities. 
+ if (mdg->outEdges.count(srcNode->id) > 0) { + SmallVector<MemRefDependenceGraph::Edge, 2> oldOutEdges = + mdg->outEdges[srcNode->id]; + for (auto &outEdge : oldOutEdges) { + if (outEdge.value == memref && + worklistSet.count(outEdge.id) == 0) { + worklist.push_back(outEdge.id); + worklistSet.insert(outEdge.id); + } + } + } } } } diff --git a/mlir/test/Transforms/loop-fusion.mlir b/mlir/test/Transforms/loop-fusion.mlir index bb8dd0db73e3..57c1ec4ceed9 100644 --- a/mlir/test/Transforms/loop-fusion.mlir +++ b/mlir/test/Transforms/loop-fusion.mlir @@ -340,10 +340,9 @@ func @should_not_fuse_would_create_cycle() { } // ----- -// CHECK: [[MAP0:#map[0-9]+]] = (d0, d1) -> (-d0 + d1) -// CHECK-LABEL: func @should_fuse_across_waw_dep_with_private_memref() { -func @should_fuse_across_waw_dep_with_private_memref() { +// CHECK-LABEL: func @should_not_fuse_across_waw_dep() { +func @should_not_fuse_across_waw_dep() { %m = alloc() : memref<10xf32> %cf7 = constant 7.0 : f32 @@ -358,16 +357,13 @@ func @should_fuse_across_waw_dep_with_private_memref() { } // Fusing loop %i0 to %i2 would violate the WAW dependence between %i0 and %i1 // CHECK: for %i0 = 0 to 10 { - // CHECK-NEXT: store %cst, %1[%i0] : memref<10xf32> + // CHECK-NEXT: store %cst, %0[%i0] : memref<10xf32> // CHECK-NEXT: } // CHECK: for %i1 = 0 to 10 { - // CHECK-NEXT: store %cst, %1[%i1] : memref<10xf32> + // CHECK-NEXT: store %cst, %0[%i1] : memref<10xf32> // CHECK-NEXT: } // CHECK: for %i2 = 0 to 10 { - // CHECK-NEXT: %2 = affine_apply [[MAP0]](%i2, %i2) - // CHECK-NEXT: store %cst, %0[%2] : memref<1xf32> - // CHECK-NEXT: %3 = affine_apply [[MAP0]](%i2, %i2) - // CHECK-NEXT: %4 = load %0[%3] : memref<1xf32> + // CHECK-NEXT: %1 = load %0[%i2] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: return return @@ -1234,17 +1230,17 @@ func @should_fuse_with_private_memrefs_with_diff_shapes() { // by loops %i1 and %i2. 
// CHECK-DAG: %0 = alloc() : memref<1xf32> // CHECK-DAG: %1 = alloc() : memref<1xf32> - // CHECK: for %i0 = 0 to 17 { + // CHECK: for %i0 = 0 to 82 { // CHECK-NEXT: %2 = affine_apply [[MAP0]](%i0, %i0) - // CHECK-NEXT: store %cst, %0[%2] : memref<1xf32> + // CHECK-NEXT: store %cst, %1[%2] : memref<1xf32> // CHECK-NEXT: %3 = affine_apply [[MAP0]](%i0, %i0) - // CHECK-NEXT: %4 = load %0[%3] : memref<1xf32> + // CHECK-NEXT: %4 = load %1[%3] : memref<1xf32> // CHECK-NEXT: } - // CHECK-NEXT: for %i1 = 0 to 82 { + // CHECK-NEXT: for %i1 = 0 to 17 { // CHECK-NEXT: %5 = affine_apply [[MAP0]](%i1, %i1) - // CHECK-NEXT: store %cst, %1[%5] : memref<1xf32> + // CHECK-NEXT: store %cst, %0[%5] : memref<1xf32> // CHECK-NEXT: %6 = affine_apply [[MAP0]](%i1, %i1) - // CHECK-NEXT: %7 = load %1[%6] : memref<1xf32> + // CHECK-NEXT: %7 = load %0[%6] : memref<1xf32> // CHECK-NEXT: } // CHECK-NEXT: return return @@ -1652,3 +1648,50 @@ func @should_fuse_at_depth_above_loop_carried_dependence(%arg0: memref<64x4xf32> // CHECK-NEXT: return return } + +// ----- + +// CHECK: [[MAP0:#map[0-9]+]] = (d0, d1) -> (-d0 + d1) + +// CHECK-LABEL: func @should_fuse_after_private_memref_creation() { +func @should_fuse_after_private_memref_creation() { + %a = alloc() : memref<10xf32> + %b = alloc() : memref<10xf32> + + %cf7 = constant 7.0 : f32 + + for %i0 = 0 to 10 { + store %cf7, %a[%i0] : memref<10xf32> + } + for %i1 = 0 to 10 { + %v0 = load %a[%i1] : memref<10xf32> + store %v0, %b[%i1] : memref<10xf32> + } + for %i2 = 0 to 10 { + %v1 = load %a[%i2] : memref<10xf32> + store %v1, %b[%i2] : memref<10xf32> + } + + // On the first visit to '%i2', the fusion algorithm can not fuse loop nest + // '%i0' into '%i2' because of the dependences '%i0' and '%i2' each have on + // '%i1'. However, once the loop nest '%i0' is fused into '%i1' with a + // private memref, the dependence between '%i0' and '%i1' on memref '%a' no + // longer exists, so '%i0' can now be fused into '%i2'. 
+ + // CHECK: for %i0 = 0 to 10 { + // CHECK-NEXT: %3 = affine_apply [[MAP0]](%i0, %i0) + // CHECK-NEXT: store %cst, %1[%3] : memref<1xf32> + // CHECK-NEXT: %4 = affine_apply [[MAP0]](%i0, %i0) + // CHECK-NEXT: %5 = load %1[%4] : memref<1xf32> + // CHECK-NEXT: store %5, %2[%i0] : memref<10xf32> + // CHECK-NEXT: } + // CHECK-NEXT: for %i1 = 0 to 10 { + // CHECK-NEXT: %6 = affine_apply [[MAP0]](%i1, %i1) + // CHECK-NEXT: store %cst, %0[%6] : memref<1xf32> + // CHECK-NEXT: %7 = affine_apply [[MAP0]](%i1, %i1) + // CHECK-NEXT: %8 = load %0[%7] : memref<1xf32> + // CHECK-NEXT: store %8, %2[%i1] : memref<10xf32> + // CHECK-NEXT: } + // CHECK-NEXT: return + return +}