From 0caa82e2ac53b2ff475531086dfe648fb2d6158a Mon Sep 17 00:00:00 2001 From: Mikhail Goncharov Date: Fri, 20 Nov 2020 13:09:28 +0100 Subject: [PATCH] Revert "[mlir][Linalg] Fuse sequence of Linalg operation (on buffers)" This reverts commit f8284d21a8e294d58a0acd4b8b2e906d7a9f110c. Revert "[mlir][Linalg] NFC: Expose some utility functions used for promotion." This reverts commit 0c59f51592ef5c014352994369f5216c6376fae1. Revert "Remove unused isZero function" This reverts commit 0f9f0a4046e11c2b4c130640f343e3b2b5db08c1. Change f8284d21 led to multiple failures in IREE compilation. --- .../Dialect/Linalg/Transforms/Transforms.h | 82 ++- .../include/mlir/Dialect/Linalg/Utils/Utils.h | 8 - mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp | 570 +++++++++--------- .../Dialect/Linalg/Transforms/Promotion.cpp | 20 +- .../Dialect/Linalg/Transforms/Transforms.cpp | 56 +- mlir/test/Dialect/Linalg/fusion-pattern.mlir | 29 +- mlir/test/Dialect/Linalg/fusion-sequence.mlir | 133 ---- .../Transforms/TestLinalgFusionTransforms.cpp | 45 -- mlir/tools/mlir-opt/mlir-opt.cpp | 2 - 9 files changed, 336 insertions(+), 609 deletions(-) delete mode 100644 mlir/test/Dialect/Linalg/fusion-sequence.mlir diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h index 87ff2a97d93f..8d531a1e343a 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -37,6 +37,14 @@ struct TiledLinalgOp { SmallVector tensorResults; }; +struct TiledAndFusedLinalgOps { + LinalgOp op; + SmallVector fusedProducers; + SmallVector originalProducers; + SmallVector fusedLoops; + SmallVector unfusedLoops; +}; + /// Populates patterns for vectorization of all ConvN-D ops. void populateConvVectorizationPatterns( MLIRContext *context, SmallVectorImpl &patterns, @@ -65,11 +73,14 @@ void populateLinalgBufferizePatterns(MLIRContext *context, Optional tileLinalgOp(OpBuilder &b, LinalgOp op, const LinalgTilingOptions &options); -/// Fuse a sequence of linalg operations (`ops`) using tile-and-fuse. This -/// proceeds as follows: -/// - Find outer parallel loops in these ops that can be fused. -/// - Tile fusable outer parallel loops of the last operation in the sequence. -/// - Fuse the remaining operations with the tiled operation +/// Tile and fuse the `op` with its producers. The tile and fuse proceeds in +/// three steps +/// - Find tile loops that are fusable with its producer tile loops (a.k.a. tile +/// + fuse loops). +/// - Tile just these loops of the consumer (root operation) and fuse with +/// the producer. +/// - Tile again the tiled consumer operation produced above to do rest of +/// the tiling specified by the `tilingOptions`. /// /// For example, consider the sequence of matmul below /// @@ -96,39 +107,36 @@ Optional tileLinalgOp(OpBuilder &b, LinalgOp op, /// : memref<256x32xf32> to memref<16x32xf32, #map0> /// %3 = subview %arg1[0, 0] [32, 32] [1, 1] /// : memref<32x32xf32> to memref<32x32xf32, #map1> -/// %4 = subview %arg3[0, 0] [32, 32] [1, 1] -/// : memref<32x32xf32> to memref<32x32xf32, #map1> /// linalg.matmul /// ins(%2, %3 : memref<16x32xf32, #map0>, memref<32x32xf32, #map1>) /// outs(%0 : memref<16x32xf32, #map0>) -/// linalg.matmul -/// ins(%0, %4 : memref<16x4xf32, #map0>, memref<4x8xf32, #map0>) -/// outs(%1 : memref<16x8xf32, #map0>) +/// scf.parallel (%arg6) = (%c0) to (%c32) step (%c8) { +/// scf.for %arg7 = %c0 to %c32 step %c4 { +/// %4 = subview %0[0, %arg7] [16, 4] [1, 1] +/// : memref<16x32xf32, #map0> to memref<16x4xf32, #map0> +/// %5 = subview %arg3[%arg7, %arg6] [4, 8] [1, 1] +/// : memref<32x32xf32> to memref<4x8xf32, #map0> +/// %6 = subview %1[0, %arg6] [16, 8] [1, 1] +/// : memref<16x32xf32, #map0> to memref<16x8xf32, #map0> +/// linalg.matmul +/// ins(%4, %5 : memref<16x4xf32, #map0>, memref<4x8xf32, #map0>) +/// outs(%6 : memref<16x8xf32, #map0>) +/// } +/// scf.yield +/// } +/// scf.yield /// } /// -/// `tilingOptions` are used to tile the corresponding operation in `ops` (the -/// size of the former should be same as size of the latter. Based on how -/// tile+fuse is implemented, the fused loops are generated based on the last -/// operation in the sequence. For example, the tile sizes for the fused loops -/// is obtained from `tilingOptions.back()`. The following tiling options are -/// handled differently in tile+fuse (compared to tile only) +/// The following tiling options are handled differently in tile+fuse (compared +/// to tile only) /// - Interchange of the tiling loops is not supported right now. -/// - Only the fused loops are distributed. -struct TiledAndFusedLinalgOps { - /// Operation obtained by tiling the last operation in sequence of `ops` - /// passed to `tileAndFuseLinalgOps`. - LinalgOp op; - /// The dimension of the loops that are fused. - std::set fusedLoopDims; - /// The generated fused operations (created within the fused loops). - SmallVector fusedProducers; - /// The fused loop generated. - SmallVector fusedLoops; -}; +/// - Distribution is only done for the tile+fuse loops. The tiled loops +/// generated by the second tiling is not distributed. Optional -tileAndFuseLinalgOps(OpBuilder &builder, ArrayRef ops, +tileAndFuseLinalgOps(PatternRewriter &rewriter, LinalgOp op, const LinalgDependenceGraph &dependenceGraph, - const LinalgTilingOptions &tilingOptions); + const LinalgTilingOptions &tilingOptions, + const LinalgFusionOptions &fusionOptions); /// Interchanges the `iterator_types` and `iterator_maps` dimensions of `op`. /// This is an in-place transformation controlled by `interchangeVector`. @@ -234,20 +242,6 @@ struct LinalgPromotionOptions { } }; -/// Creates a new buffer using the `allocationFn` provided. The size of this -/// buffer is the smallest constant bounding size along each dimension that can -/// be computed for the size of the result of `subView`. Returns the allocated -/// buffer as `fullLocalView` and the view that matches the size of the result -/// of subview operation as `partialLocalView`. -struct PromotionInfo { - Value fullLocalView; - Value partialLocalView; -}; -Optional -promoteSubviewAsNewBuffer(OpBuilder &b, Location loc, SubViewOp subView, - AllocBufferCallbackFn allocationFn, - OperationFolder *folder = nullptr); - /// Promotes the `subViews` into a new buffer allocated at the insertion point /// `b`. Promotion occurs in 3 steps: /// 1. Create a new buffer for a full tile (i.e. not clipped at the boundary). diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h index 1eaf8b0e709c..f5669e383368 100644 --- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h @@ -17,7 +17,6 @@ #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" -#include "llvm/ADT/MapVector.h" #include "llvm/ADT/SetVector.h" using mlir::edsc::intrinsics::AffineIndexedValue; @@ -83,13 +82,6 @@ bool isProducerLastWriteOfView(const LinalgDependenceGraph &graph, bool isFusableInto(const LinalgDependenceGraph &graph, LinalgOp consumer, Value consumedView, LinalgOp producer); -using FusableOpDependencesTy = llvm::MapVector< - Operation *, - SmallVector>; -FusableOpDependencesTy -findAllFusableDependences(ArrayRef ops, - const LinalgDependenceGraph &dependenceGraph); - /// Fuses producer into consumer if the producer is structurally feasible and /// the fusion would not violate dependencies. /// Implements the fusion part of the "tileAndFuse on buffers" diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp index 45a68fcba4a2..969bea4a4549 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp @@ -178,9 +178,6 @@ static ShapeDimension getShapeDefiningLoopRange(LinalgOp op, Value shape = en.value(); SmallVector shapeRanges(map.getNumResults(), nullptr); for (auto en2 : llvm::enumerate(map.getResults())) { - auto dimExpr = en2.value().dyn_cast(); - if (!dimExpr) - continue; if (loopDepth == en2.value().cast().getPosition()) { LLVM_DEBUG(llvm::dbgs() << "getShapeDefiningLoopRange loopDepth: " << loopDepth << "\n"); @@ -193,18 +190,49 @@ static ShapeDimension getShapeDefiningLoopRange(LinalgOp op, llvm_unreachable("Expect to be able to extract a shape defining loop range"); } -/// Fuse the producer by cloning the `producer`. The `fusedLoopsAndRanges` -/// provides the loop range information for the fused loops. The rest are -/// obtained from the producer itself, since they are not tiled + fused. -static LinalgOp fuse(OpBuilder &b, LinalgOp producer, - const DenseMap &fusedLoopsAndRanges) { +/// Fuses the producer of `producerIdx` into the loop immediately enclosing +/// `consumer`. This is achieved by "recomputing" the `producer` at the time it +/// is needed just before the `consumer. +/// +/// Depending on the type of `consumer.getShapedOperand(consumerIdx)`, there are +/// 2 cases: +/// 1. Buffer case: `producerIdx` is the index of the buffer in +/// `producer.getOutputBuffers()`. +/// 2. Tensor case: `producerIdx` is the index of the tensor in +/// `producer.getResults()`. +static LinalgOp fuse(OpBuilder &b, LinalgOp producer, unsigned producerIdx, + LinalgOp consumer, unsigned consumerIdx) { + Operation *shapeProducingOp = + consumer.getShapedOperand(consumerIdx).getDefiningOp(); + assert((isa(shapeProducingOp) || + isa(shapeProducingOp)) && + "SubviewOp or SubTensorOp expected"); + + // loopToOperandRangesMaps are permutations-only by construction: + // we can always identify a data dimension with a (at least one) loop + // dimension. + // TODO: extend this with range inference. + AffineMap producerMap = producer.getOutputIndexingMap(producerIdx); + LLVM_DEBUG(llvm::dbgs() << "Producer Idx: " << producerIdx + << ", producer map: " << producerMap << "\n"); unsigned nPar = producer.getNumParallelLoops(); unsigned nRed = producer.getNumReductionLoops(); unsigned nWin = producer.getNumWindowLoops(); SmallVector loopRanges(nPar + nRed + nWin); - for (auto fusedLoops : fusedLoopsAndRanges) - loopRanges[fusedLoops.first] = fusedLoops.second; + + // Iterate over dimensions identified by the producer map for `producerIdx`. + // This defines a subset of the loop ranges that we need to complete later. + auto loc = consumer.getLoc(); + for (auto en : llvm::enumerate(producerMap.getResults())) { + unsigned posInProducerLoop = en.value().cast().getPosition(); + loopRanges[posInProducerLoop] = + isa(shapeProducingOp) + ? cast(shapeProducingOp) + .getOrCreateRanges(b, loc)[en.index()] + : cast(shapeProducingOp) + .getOrCreateRanges(b, loc)[en.index()]; + } // Iterate over all dimensions. For the dimensions not identified by the // producer map for `producerIdx`, we need to explicitly compute the shape @@ -222,45 +250,7 @@ static LinalgOp fuse(OpBuilder &b, LinalgOp producer, } } - return cloneWithLoopRanges(b, producer.getLoc(), producer, loopRanges); -} - -/// Get the loop range for a dimension `dim` based on the `shapedOperand`. It is -/// expected to be defined by a subview op or a subtensor op. -static Range getRangeFromOperandShape(OpBuilder &b, Location loc, - Value shapedOperand, unsigned dim) { - Operation *shapeProducingOp = shapedOperand.getDefiningOp(); - if (auto subViewOp = dyn_cast(shapeProducingOp)) - return subViewOp.getOrCreateRanges(b, loc)[dim]; - if (auto subTensorOp = dyn_cast(shapeProducingOp)) - return subTensorOp.getOrCreateRanges(b, loc)[dim]; - llvm_unreachable("SubviewOp or SubTensorOp expected"); -} - -/// Fuses the producer of `producerIdx` into the loop immediately enclosing -/// `consumer`. This is achieved by "recomputing" the `producer` at the time it -/// is needed just before the `consumer. -/// -/// Depending on the type of `consumer.getShapedOperand(consumerIdx)`, there are -/// 2 cases: -/// 1. Buffer case: `producerIdx` is the index of the buffer in -/// `producer.getOutputBuffers()`. -/// 2. Tensor case: `producerIdx` is the index of the tensor in -/// `producer.getResults()`. -static LinalgOp fuse(OpBuilder &b, LinalgOp producer, unsigned producerIdx, - LinalgOp consumer, unsigned consumerIdx) { - AffineMap producerMap = producer.getOutputIndexingMap(producerIdx); - LLVM_DEBUG(llvm::dbgs() << "Producer Idx: " << producerIdx - << ", producer map: " << producerMap << "\n"); - DenseMap fusedLoopsAndRanges; - Location loc = consumer.getLoc(); - Value shapedOperand = consumer.getShapedOperand(consumerIdx); - for (auto en : llvm::enumerate(producerMap.getResults())) { - unsigned posInProducerLoop = en.value().cast().getPosition(); - fusedLoopsAndRanges[posInProducerLoop] = - getRangeFromOperandShape(b, loc, shapedOperand, en.index()); - } - return fuse(b, producer, fusedLoopsAndRanges); + return cloneWithLoopRanges(b, loc, producer, loopRanges); } // Encode structural fusion safety preconditions. @@ -531,68 +521,9 @@ static AffineMap pruneReductionDimsFromMap(ArrayRef iteratorTypes, return getProjectedMap(map, projectedDims); } -/// Returns the mapping from iterations in the consumer that write to the same -/// location as the iterations in the producer. To do so use -/// - indexing map of the fused view in the consumer : consumerIndexMap -/// - indexing map of the fused view in the producer : producerIndexMap -/// consumerLoopToProducerLoop = -/// inverse(producerIndexMap).compose(consumerIndexMap) -static Optional getConsumerLoopToProducerLoopMap( - LinalgDependenceGraph::LinalgDependenceGraphElem dependence) { - auto producer = cast(dependence.dependentOpView.op); - AffineMap producerIndexingMap = - producer.getIndexingMap(dependence.dependentOpView.operandIndex); - auto consumer = cast(dependence.indexingOpView.op); - AffineMap consumerIndexingMap = - consumer.getIndexingMap(dependence.indexingOpView.operandIndex); - - AffineMap prunedProducerIndexingMap = pruneReductionDimsFromMap( - producer.iterator_types().getValue(), producerIndexingMap); - if (!prunedProducerIndexingMap.isPermutation()) - return None; - - if (consumerIndexingMap.getNumResults() != - prunedProducerIndexingMap.getNumResults()) - return None; - - LLVM_DEBUG({ - llvm::dbgs() << "\t producerMap : "; - producerIndexingMap.print(llvm::dbgs()); - llvm::dbgs() << " pruned : "; - prunedProducerIndexingMap.print(llvm::dbgs()); - llvm::dbgs() << "\n"; - llvm::dbgs() << "\t consumerMap : "; - consumerIndexingMap.print(llvm::dbgs()); - llvm::dbgs() << "\n"; - }); - - AffineMap invProducerIndexMap = inversePermutation(prunedProducerIndexingMap); - if (!invProducerIndexMap) - return None; - - return invProducerIndexMap.compose(consumerIndexingMap); -} - -/// Given a projected permutation `map`, returns true if the map changes the -/// order in which the fused loop dimension appear. -static bool doesTransposeAccess(AffineMap map, - const std::set &fusableLoops) { - Optional lastFusableLoop; - for (unsigned pos : llvm::map_range(map.getResults(), [](AffineExpr expr) { - return expr.cast().getPosition(); - })) { - if (!fusableLoops.count(pos)) - continue; - if (!lastFusableLoop) { - lastFusableLoop = pos; - continue; - } - if (pos <= lastFusableLoop.getValue()) - return true; - lastFusableLoop = pos; - } - return false; -} +using FusableOpDependencesTy = llvm::MapVector< + Operation *, + SmallVector>; /// Returns the positions of the loop in `op` that can be tiled based on the /// operations that are to be fused with it. For example, in a @@ -607,7 +538,13 @@ static bool doesTransposeAccess(AffineMap map, /// 2. Of the parallel loops only some can be fused. Only those loops can be /// fused such where the fusable loops iteration space only touches one tile /// of the fused operation. This is because the producer (which is writing -/// the fused subview) has update semantics. +/// the fused subview) has update semantics. To compute this, +/// a. Find the mapping from iterations in the consumer that write to the +/// same location as the iterations in the producer. To do so use +/// - indexing map of the fused view in the consumer : consumerIndexMap +/// - indexing map of the fused view in the producer : producerIndexMap +/// consumerLoopToProducerLoop = +/// inverse(producerIndexMap).compose(consumerIndexMap) /// /// Since an inverse computation is needed, we need to consider the projection /// of the producerIndexMap w.r.t the parallel loops. The actual fusable loops @@ -645,9 +582,8 @@ static bool doesTransposeAccess(AffineMap map, /// submap with only parallel loops = affine_map<(i, j) -> (j)> /// Fused dimensions : j static std::set -collectFusableLoops(ArrayRef ops, - const FusableOpDependencesTy &fusableDependences) { - assert(!ops.empty()); +collectTileAndFuseLoops(LinalgOp op, + const FusableOpDependencesTy &fusableDependences) { auto getNumOuterParallelLoops = [](LinalgOp linalgOp) { return linalgOp.iterator_types() .getValue() @@ -658,245 +594,289 @@ collectFusableLoops(ArrayRef ops, .size(); }; - size_t numOuterParallelLoops = getNumOuterParallelLoops(ops.back()); - for (auto op : ops.drop_back()) { + LLVM_DEBUG({ + llvm::dbgs() << "Op : "; + op.getOperation()->print(llvm::dbgs(), OpPrintingFlags().useLocalScope()); + llvm::dbgs() << "\n"; + }); + + size_t numOuterParallelLoops = getNumOuterParallelLoops(op); + for (auto dependence : fusableDependences) { + linalg::LinalgOp producer = cast(dependence.first); numOuterParallelLoops = - std::min(numOuterParallelLoops, getNumOuterParallelLoops(op)); + std::min(numOuterParallelLoops, getNumOuterParallelLoops(producer)); } std::set fusableLoops; auto range = llvm::seq(0, numOuterParallelLoops); fusableLoops.insert(range.begin(), range.end()); + for (auto dependence : fusableDependences) { + LLVM_DEBUG({ + llvm::dbgs() << "\t fusable :"; + for (unsigned i : fusableLoops) + llvm::dbgs() << " " << i; + llvm::dbgs() << "\n"; + }); + linalg::LinalgOp producer = cast(dependence.first); - for (auto op : reverse(ops)) { - for (auto dependence : fusableDependences.lookup(op)) { - LLVM_DEBUG({ - llvm::dbgs() << "\t fusable :"; - for (unsigned i : fusableLoops) - llvm::dbgs() << " " << i; - llvm::dbgs() << "\n"; - }); + assert(!dependence.second.empty() && + "unexpected producer but not dependences"); + AffineMap producerIndexingMap = producer.getIndexingMap( + dependence.second.front().dependentOpView.operandIndex); + AffineMap prunedProducerIndexingMap = pruneReductionDimsFromMap( + producer.iterator_types().getValue(), producerIndexingMap); + if (!prunedProducerIndexingMap.isPermutation()) + return {}; - Optional consumerLoopToProducerLoop = - getConsumerLoopToProducerLoopMap(dependence); - if (!consumerLoopToProducerLoop) { - op.emitRemark("failed to get map from consumer loop to producer loop"); - return {}; - } - // todo: This condition is only an implementation limitation. When fusing - // the operation, if the accesses in the producer/consumer are transposes - // of each other, the loop bounds for the tiled producer can be - // manipulated accordingly. This requires some additional bookkeeping in - // the implementation of tile+fuse that is defered to later. - if (doesTransposeAccess(*consumerLoopToProducerLoop, fusableLoops)) { - op.emitRemark("unhandled fusion when fusion requires permutation"); - return {}; - } + AffineMap consumerIndexingMap = op.getIndexingMap( + dependence.second.front().indexingOpView.operandIndex); + if (consumerIndexingMap.getNumResults() != + prunedProducerIndexingMap.getNumResults()) + return {}; - std::set candidates; - for (AffineExpr expr : consumerLoopToProducerLoop->getResults()) { - unsigned position = expr.cast().getPosition(); - if (fusableLoops.count(position)) - candidates.insert(position); - } - LLVM_DEBUG({ - llvm::dbgs() << "\t candidates :"; - for (unsigned i : candidates) - llvm::dbgs() << " " << i; - llvm::dbgs() << "\n"; - }); - if (candidates.empty()) - return {}; - std::swap(candidates, fusableLoops); + LLVM_DEBUG({ + llvm::dbgs() << "\t producerMap : "; + producerIndexingMap.print(llvm::dbgs()); + llvm::dbgs() << " pruned : "; + prunedProducerIndexingMap.print(llvm::dbgs()); + llvm::dbgs() << "\n"; + llvm::dbgs() << "\t consumerMap : "; + consumerIndexingMap.print(llvm::dbgs()); + llvm::dbgs() << "\n"; + }); + + AffineMap invProducerIndexMap = + inversePermutation(prunedProducerIndexingMap); + if (!invProducerIndexMap) + return {}; + + AffineMap consumerLoopToProducerLoop = + invProducerIndexMap.compose(consumerIndexingMap); + + LLVM_DEBUG({ + llvm::dbgs() << "\t consumerLoopToProducerLoop : "; + consumerLoopToProducerLoop.print(llvm::dbgs()); + }); + + std::set candidates; + for (AffineExpr expr : consumerLoopToProducerLoop.getResults()) { + AffineDimExpr dimExpr = expr.dyn_cast(); + if (!dimExpr) + continue; + unsigned position = dimExpr.getPosition(); + if (fusableLoops.count(position)) + candidates.insert(position); } + LLVM_DEBUG({ + llvm::dbgs() << "\t candidates :"; + for (unsigned i : candidates) + llvm::dbgs() << " " << i; + llvm::dbgs() << "\n"; + }); + if (candidates.empty()) + return {}; + std::swap(candidates, fusableLoops); } return fusableLoops; } -/// Find all dependences that are fusable. -FusableOpDependencesTy mlir::linalg::findAllFusableDependences( - ArrayRef ops, const LinalgDependenceGraph &dependenceGraph) { +/// Find all dependences that are to be fusable. +static FusableOpDependencesTy +findAllFusableDependences(LinalgOp op, + const LinalgDependenceGraph &dependenceGraph, + const LinalgFusionOptions &fusionOptions) { FusableOpDependencesTy fusableDependences; // TODO: Currently fusion would not be legal if the fusable dependence is to // the same producer but different indexing map in the consumer. Fix this, but // in the meanwhile disallow such a fusion. DenseMap fusedProducerIndexingMap; - for (LinalgOp op : reverse(ops)) { - for (auto operandIndex : - llvm::seq(0, op.getNumInputsAndOutputBuffers())) { - Optional - fusableDependence = - findFusableProducer(op, operandIndex, dependenceGraph); - if (!fusableDependence) - continue; - LinalgOp producerOp = - cast(fusableDependence->dependentOpView.op); - // Do not fuse dependences that are to operations not in the same basic - // block. This avoid moving fused operations across loops that might - // themselves carry dependency making the fusion illegal. - if (producerOp.getOperation()->getBlock() != - op.getOperation()->getBlock()) { - op.emitRemark("unhandled fusion of ops in different basic blocks"); - return FusableOpDependencesTy{}; - } - // Make sure that the indexing map of the view used for fusion in the - // producer is a projected permutation. - unsigned producerIdx = fusableDependence->dependentOpView.operandIndex; - AffineMap producerMap = producerOp.getIndexingMap(producerIdx); - if (!producerMap.isProjectedPermutation()) { - op.emitRemark( - "unhandled non permutation indexing map for fused view in " - "producer for operand at index ") - << operandIndex; - return FusableOpDependencesTy{}; - } - - unsigned consumerIdx = fusableDependence->indexingOpView.operandIndex; - AffineMap consumerMap = op.getIndexingMap(consumerIdx); - if (!consumerMap.isProjectedPermutation()) { - op.emitRemark( - "unhandled case where indexing map for fused view in the consumer " - "is " - "not a projected permuration while fusing at index ") - << operandIndex; - return FusableOpDependencesTy{}; - } - - // Check if the producer is already a fusion candidate. Cannot fuse this - // dependence if it has a different indexing map when used in the - // consumer. - if (fusedProducerIndexingMap.count(producerOp.getOperation()) && - fusedProducerIndexingMap[producerOp.getOperation()] != consumerMap) { - op.emitRemark( - "unhandled fusion to the same producer but with different " - "indexing maps"); - return FusableOpDependencesTy{}; - } - fusedProducerIndexingMap[producerOp.getOperation()] = consumerMap; - - fusableDependences[producerOp.getOperation()].push_back( - *fusableDependence); + for (auto operandIndex : fusionOptions.indicesToFuse) { + auto fusableDependence = + findFusableProducer(op, operandIndex, dependenceGraph); + if (!fusableDependence) + return FusableOpDependencesTy{}; + LinalgOp producerOp = cast(fusableDependence->dependentOpView.op); + // Do not fuse dependences that are to operations not in the same basic + // block. This avoid moving fused operations across loops that might + // themselves carry dependency making the fusion illegal. + if (producerOp.getOperation()->getBlock() != + op.getOperation()->getBlock()) { + op.emitRemark("unhandled fusion of ops in different basic blocks"); + return FusableOpDependencesTy{}; } + // Make sure that the indexing map of the view used for fusion in the + // producer is a projected permutation. + unsigned producerIdx = fusableDependence->dependentOpView.operandIndex; + AffineMap producerMap = producerOp.getIndexingMap(producerIdx); + if (!producerMap.isProjectedPermutation()) { + op.emitRemark("unhandled non permutation indexing map for fused view in " + "producer for operand at index ") + << operandIndex; + return FusableOpDependencesTy{}; + } + + unsigned consumerIdx = fusableDependence->indexingOpView.operandIndex; + AffineMap consumerMap = op.getIndexingMap(consumerIdx); + if (!consumerMap.isProjectedPermutation()) { + op.emitRemark( + "unhandled case where indexing map for fused view in the consumer is " + "not a projected permutation while fusing at index ") + << operandIndex; + return FusableOpDependencesTy{}; + } + + // Check if the producer is already a fusion candidate. Cannot fuse this + // dependence if it has a different indexing map when used in the consumer. + if (fusedProducerIndexingMap.count(producerOp.getOperation()) && + fusedProducerIndexingMap[producerOp.getOperation()] != consumerMap) { + op.emitRemark("unhandled fusion to the same producer but with different " + "indexing maps"); + return FusableOpDependencesTy{}; + } + fusedProducerIndexingMap[producerOp.getOperation()] = consumerMap; + + fusableDependences[producerOp.getOperation()].push_back(*fusableDependence); } return fusableDependences; } -/// Tile the fused loops in the root operation, by setting the tile sizes for -/// all other loops to zero (those will be tiled later). -static Optional tileRootOperation( - OpBuilder &builder, LinalgOp op, ArrayRef tileSizeVector, - const LinalgTilingOptions &options, const std::set &fusedLoops) { - SmallVector tileSizes(tileSizeVector.begin(), tileSizeVector.end()); - auto zero = std_constant_index(0); - for (unsigned i = 0, e = tileSizes.size(); i != e; ++i) - if (!fusedLoops.count(i)) - tileSizes[i] = zero; - LinalgTilingOptions tileFusedLoopsOptions = options; - tileFusedLoopsOptions.setTileSizes(tileSizes); - return tileLinalgOp(builder, op, tileFusedLoopsOptions); -} - -/// Fuse the operations in `fusionCandidates` with `tiledOp`. Latter is expected -/// to be a tiled operation such that it is valid to fuse all operations in -/// `fusionCandidates`, i.e. move the operation within the inter-tile loops of -/// `tiledOp`. -static SmallVector -fuseOperations(OpBuilder &builder, LinalgOp tiledOp, - ArrayRef fusionCandidates, - const FusableOpDependencesTy &fusableDependences, - const std::set &fusedLoops) { - OpBuilder::InsertionGuard guard(builder); - builder.setInsertionPoint(tiledOp); - DenseMap fusedLoopsAndRanges; - for (unsigned loop : fusedLoops) { - ShapeDimension shapeDim = getShapeDefiningLoopRange(tiledOp, loop); - fusedLoopsAndRanges[loop] = getRangeFromOperandShape( - builder, tiledOp.getLoc(), shapeDim.shape, shapeDim.dimension); - } - - SmallVector fusedOps(fusionCandidates.size()); - for (auto candidate : enumerate(llvm::reverse(fusionCandidates))) { - LinalgOp fusedOp = fuse(builder, candidate.value(), fusedLoopsAndRanges); - fusedOps[fusionCandidates.size() - candidate.index() - 1] = fusedOp; - builder.setInsertionPoint(fusedOp); - } - return fusedOps; +static bool isZero(Value v) { + if (auto cst = v.getDefiningOp()) + return cst.getValue() == 0; + return false; } template static Optional -tileAndFuseLinalgOpsImpl(OpBuilder &builder, ArrayRef ops, +tileAndFuseLinalgOpsImpl(PatternRewriter &rewriter, LinalgOp op, const LinalgDependenceGraph &dependenceGraph, - const LinalgTilingOptions &tilingOptions) { - if (ops.empty()) - return llvm::None; - LinalgOp rootOp = ops.back(); - for (auto op : enumerate(ops)) { - // TODO: Nothing in the fusion of sequence of ops is specific to - // buffers. This check can be removed after it is tested on tensors. - LinalgOp linalgOp = op.value(); - if (!linalgOp.hasBufferSemantics()) { - linalgOp.emitError("tile and fuse only tested for buffer operation"); - return llvm::None; - } - } - // TODO: Support interchange with tile + fuse. This might actually help do - // better fusion. + const LinalgTilingOptions &tilingOptions, + const LinalgFusionOptions &fusionOptions) { + assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics"); + // Some of the tiling options might not be supportable with tile and fuse. + // TODO: Support interchange with tile + fuse. if (!tilingOptions.interchangeVector.empty()) { - rootOp.emitError("unable to handle tile and fuse with interchange"); + op.emitError("unable to handle tile and fuse with interchange"); return llvm::None; } - OpBuilder::InsertionGuard guard(builder); - builder.setInsertionPoint(rootOp); - ScopedContext scope(builder, rootOp.getLoc()); + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(op); + ScopedContext scope(rewriter, op.getLoc()); // Find all the producers. FusableOpDependencesTy fusableDependences = - findAllFusableDependences(ops, dependenceGraph); + findAllFusableDependences(op, dependenceGraph, fusionOptions); if (fusableDependences.empty()) return llvm::None; + // Enforce the convention that "tiling by zero" skips tiling a particular + // dimension. This convention is significantly simpler to handle instead of + // adjusting affine maps to account for missing dimensions. + auto nLoops = op.getNumLoops(); + SmallVector tileSizeVector = + tilingOptions.tileSizeComputationFunction(rewriter, op); + if (tileSizeVector.size() < nLoops) { + auto zero = std_constant_index(0); + tileSizeVector.append(nLoops - tileSizeVector.size(), zero); + } + TiledAndFusedLinalgOps ret; + // Find the loops that can be tiled and fused. - ret.fusedLoopDims = collectFusableLoops(ops, fusableDependences); + std::set tileFuseLoops = + collectTileAndFuseLoops(op, fusableDependences); // If there are no fusable dependences or there are no tile+fusable loops, // just return. - if (ret.fusedLoopDims.empty()) { + if (tileFuseLoops.empty()) { return llvm::None; } - // Tile the fused loops in the last operation in the list. - SmallVector tileSizeVector = - tilingOptions.tileSizeComputationFunction(builder, rootOp); - Optional tiledRootOp = tileRootOperation( - builder, rootOp, tileSizeVector, tilingOptions, ret.fusedLoopDims); - if (!tiledRootOp) { - rootOp.emitError("failed to tile the fused loops"); - return llvm::None; + // Get the tile sizes for the first and second tiling steps. For the first + // step the tile size are set to zero for the loops that arent + // fused. Similarly for the second step, the tile sizes are set to zero for + // the loops that are fused. For example, if for the following input + // + // ``` + // linalg.add ins(%a, %b) outs(%c) + // linalg.matmul ins(%d, %c) outs(%e) + // ``` + // + // if the tile sizes of the `{i, j, k}` loops where given as `{ti, tj, tk}` + // respectively, and since only `j` can be tiled and fused. The tile sizes + // would be `{0, t_j, 0}` for the first tiling that tiles just the fusable + // loops. The second tiling would be use tile sizes of `{t_i, 0, t_k}` to tile + // the tiled matmul generated by the first tiling step. + SmallVector tileAndFuseSizes, tileSizes; + for (auto tileSize : enumerate(tileSizeVector)) { + auto zero = std_constant_index(0); + if (tileFuseLoops.count(tileSize.index())) { + tileAndFuseSizes.push_back(tileSize.value()); + tileSizes.push_back(zero); + } else { + tileSizes.push_back(tileSize.value()); + tileAndFuseSizes.push_back(zero); + } + } + + // Tile for the loops that can be fused. + LinalgTilingOptions firstTilingOptions = tilingOptions; + firstTilingOptions.setTileSizes(tileAndFuseSizes); + Optional firstTiledOp = + tileLinalgOp(rewriter, op, firstTilingOptions); + if (!firstTiledOp) + return llvm::None; + ret.op = firstTiledOp->op; + ret.fusedLoops.assign(firstTiledOp->loops.begin(), firstTiledOp->loops.end()); + + rewriter.setInsertionPoint(ret.op); + // Fuse the operands. + for (auto dependence : fusableDependences) { + LinalgOp producerOp = cast(dependence.first); + unsigned producerIdx = + dependence.second.front().dependentOpView.operandIndex; + unsigned consumerIdx = + dependence.second.front().indexingOpView.operandIndex; + LinalgOp fusedOp = fuse(rewriter, producerOp, + producerOp.getOutputIndex(producerIdx).getValue(), + ret.op, consumerIdx); + ret.fusedProducers.push_back(fusedOp); + ret.originalProducers.push_back(producerOp); + } + + if (!llvm::all_of(tileSizes, isZero)) { + // Tile the remaining loops of the root operation. + LinalgTilingOptions secondTilingOptions = tilingOptions; + // The distribution is done only for the tile+fused loops. + secondTilingOptions.distribution = llvm::None; + secondTilingOptions.setTileSizes(tileSizes); + Optional secondTiledOp = + tileLinalgOp(rewriter, ret.op, secondTilingOptions); + if (!secondTiledOp) + return llvm::None; + ret.unfusedLoops.assign(secondTiledOp->loops.begin(), + secondTiledOp->loops.end()); + rewriter.eraseOp(ret.op); + ret.op = secondTiledOp->op; } - ret.op = tiledRootOp->op; - ret.fusedLoops.assign(tiledRootOp->loops.begin(), tiledRootOp->loops.end()); - // Fuse the other operations into the fused inter-tile loops produced above. - ret.fusedProducers = fuseOperations(builder, ret.op, ops.drop_back(), - fusableDependences, ret.fusedLoopDims); return ret; } Optional -mlir::linalg::tileAndFuseLinalgOps(OpBuilder &builder, ArrayRef ops, +mlir::linalg::tileAndFuseLinalgOps(PatternRewriter &rewriter, LinalgOp op, const LinalgDependenceGraph &dependenceGraph, - const LinalgTilingOptions &tilingOptions) { + const LinalgTilingOptions &tilingOptions, + const LinalgFusionOptions &fusionOptions) { switch (tilingOptions.loopType) { case LinalgTilingLoopType::Loops: - return tileAndFuseLinalgOpsImpl(builder, ops, dependenceGraph, - tilingOptions); + return tileAndFuseLinalgOpsImpl(rewriter, op, dependenceGraph, + tilingOptions, fusionOptions); case LinalgTilingLoopType::ParallelLoops: return tileAndFuseLinalgOpsImpl( - builder, ops, dependenceGraph, tilingOptions); + rewriter, op, dependenceGraph, tilingOptions, fusionOptions); default:; } return llvm::None; diff --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp index a824f6eb620f..e002336ed1c6 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp @@ -166,6 +166,11 @@ struct LinalgOpInstancePromotionOptions { /// Alignment of promoted buffer. Optional alignment; }; + +struct PromotionInfo { + Value fullLocalView; + Value partialLocalView; +}; } // namespace LinalgOpInstancePromotionOptions::LinalgOpInstancePromotionOptions( @@ -228,10 +233,10 @@ LinalgOpInstancePromotionOptions::LinalgOpInstancePromotionOptions( // To account for general boundary effects, padding must be performed on the // boundary tiles. For now this is done with an unconditional `fill` op followed // by a partial `copy` op. -Optional mlir::linalg::promoteSubviewAsNewBuffer( - OpBuilder &b, Location loc, SubViewOp subView, - AllocBufferCallbackFn allocationFn, OperationFolder *folder) { - ScopedContext scopedContext(b, loc); +static Optional +promoteSubviewAsNewBuffer(OpBuilder &b, Location loc, SubViewOp subView, + LinalgOpInstancePromotionOptions const &options, + OperationFolder *folder) { auto viewType = subView.getType(); auto rank = viewType.getRank(); SmallVector fullSizes, partialSizes; @@ -249,7 +254,8 @@ Optional mlir::linalg::promoteSubviewAsNewBuffer( SmallVector dynSizes(fullSizes.size(), -1); // If a callback is not specified, then use the default implementation for // allocating the promoted buffer. - Optional fullLocalView = allocationFn(b, subView, fullSizes, folder); + Optional fullLocalView = + options.allocationFn(b, subView, fullSizes, folder); if (!fullLocalView) return {}; auto zero = folded_std_constant_index(folder, 0); @@ -273,8 +279,8 @@ promoteSubViews(OpBuilder &b, Location loc, for (auto v : options.subViews) { SubViewOp subView = cast(v.second.getDefiningOp()); - Optional promotionInfo = promoteSubviewAsNewBuffer( - b, loc, subView, options.allocationFn, folder); + Optional promotionInfo = + promoteSubviewAsNewBuffer(b, loc, subView, options, folder); if (!promotionInfo) return {}; promotionInfoMap[v.first] = *promotionInfo; diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp index a855c07cb8d4..836cc28e0a47 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -165,69 +165,17 @@ LogicalResult mlir::linalg::LinalgBaseTileAndFusePattern::matchAndRewrite( if (!linalgOp.hasBufferSemantics()) return failure(); - DenseSet producers; - producers.insert(linalgOp); - for (auto dependence : dependenceGraph.getDependentOperations(linalgOp)) { - if (!fusionOptions.indicesToFuse.count( - dependence.indexingOpView.operandIndex)) - continue; - if (isa(dependence.dependentOpView.op)) - producers.insert(dependence.dependentOpView.op); - } - - SmallVector fusionOps; - for (auto it = op->getBlock()->begin(), ie = Block::iterator(op); it != ie; - ++it) { - auto producerLinalgOp = dyn_cast(&(*it)); - if (producerLinalgOp && producers.count(producerLinalgOp)) - fusionOps.push_back(producerLinalgOp); - } - fusionOps.push_back(linalgOp); - - SmallVector tileSizes = - tilingOptions.tileSizeComputationFunction(rewriter, op); - LinalgTilingOptions instanceTilingOptions = tilingOptions; - instanceTilingOptions.setTileSizes(tileSizes); Optional tiledAndFusedOps = tileAndFuseLinalgOps( - rewriter, fusionOps, dependenceGraph, instanceTilingOptions); + rewriter, op, dependenceGraph, tilingOptions, fusionOptions); if (!tiledAndFusedOps) return failure(); - - // Tile the unfused loops; - SmallVector unfusedLoopTileSizes; - Value zero = rewriter.create(op->getLoc(), 0); - for (auto tileSize : enumerate(tileSizes)) { - if (tiledAndFusedOps->fusedLoopDims.count(tileSize.index())) - unfusedLoopTileSizes.push_back(zero); - else - unfusedLoopTileSizes.push_back(tileSize.value()); - } - // Tile the loop only if there is a non-zero tile size. - if (unfusedLoopTileSizes.size() > linalgOp.getNumLoops()) - unfusedLoopTileSizes.resize(linalgOp.getNumLoops()); - if (llvm::any_of(unfusedLoopTileSizes, [](Value val) { - if (auto cst = val.getDefiningOp()) - return cst.getValue() != 0; - return true; - })) { - LinalgTilingOptions unfusedTilingOptions = tilingOptions; - unfusedTilingOptions.setTileSizes(unfusedLoopTileSizes); - Optional unfusedTiledOp = - tileLinalgOp(rewriter, tiledAndFusedOps->op, unfusedTilingOptions); - if (!unfusedTiledOp) - return failure(); - rewriter.eraseOp(tiledAndFusedOps->op); - tiledAndFusedOps->op = unfusedTiledOp->op; - } - marker.replaceLinalgMarker(rewriter, tiledAndFusedOps->op.getOperation()); for (auto fusedOp : tiledAndFusedOps->fusedProducers) { fusedOpMarker.replaceLinalgMarker(rewriter, fusedOp.getOperation()); } - for (auto origProducerOp : ArrayRef(fusionOps).drop_back()) { + for (auto origProducerOp : tiledAndFusedOps->originalProducers) originalOpMarker.replaceLinalgMarker(rewriter, origProducerOp.getOperation()); - } rewriter.updateRootInPlace( op, [&]() { originalOpMarker.replaceLinalgMarker(rewriter, op); }); return success(); diff --git a/mlir/test/Dialect/Linalg/fusion-pattern.mlir b/mlir/test/Dialect/Linalg/fusion-pattern.mlir index fa471811ef4e..2ddc66651db2 100644 --- a/mlir/test/Dialect/Linalg/fusion-pattern.mlir +++ b/mlir/test/Dialect/Linalg/fusion-pattern.mlir @@ -47,9 +47,7 @@ module { // CHECK: %[[TILE_N_2:.+]] = affine.min #[[MAP2]](%[[IV1]])[%[[N_2]]] // CHECK: %[[SV3:.+]] = subview %[[ARG2]][%[[IV0]], %[[IV1]]] // CHECK-SAME: [%[[TILE_M_2]], %[[TILE_N_2]]] -// CHECK: %[[SV3_2:.+]] = subview %[[ARG2]][%[[IV0]], %[[IV1]]] -// CHECK-SAME: [%[[TILE_M]], %[[TILE_N]]] -// CHECK: linalg.fill(%[[SV3_2]], %[[CST]]) +// CHECK: linalg.fill(%[[SV3]], %[[CST]]) // CHECK-SAME: __internal_linalg_transform__ = "after_basic_fusion_producer" // CHECK: scf.for %[[IV2:.+]] = %[[C0]] to %[[K]] step %[[C16]] { // CHECK: %[[TILE_K:.+]] = affine.min #[[MAP3]](%[[IV2]])[%[[K]]] @@ -111,12 +109,9 @@ module { // CHECK: %[[TILE_N_2:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[N_2]]] // CHECK: %[[SV2:.+]] = subview %[[ARG3]][0, %[[IV0]]] // CHECK-SAME: [%[[M]], %[[TILE_N_2]]] -// CHECK: %[[K_2:.+]] = dim %[[ARG1]], %[[C0]] // CHECK: %[[SV3:.+]] = subview %[[ARG1]][0, %[[IV0]]] -// CHECK-SAME: [%[[K_2]], %[[TILE_N]]] -// CHECK: %[[SV3_2:.+]] = subview %[[ARG2]][0, %[[IV0]]] -// CHECK-SAME: [%[[K_2]], %[[TILE_N]]] -// CHECK: linalg.copy(%[[SV3]], %[[SV3_2]]) +// CHECK-SAME: [%[[K]], %[[TILE_N]]] +// CHECK: linalg.copy(%[[SV3]], %[[SV1]]) // CHECK-SAME: __internal_linalg_transform__ = "after_rhs_fusion_producer" // CHECK-NOT: linalg.fill // CHECK-DAG: %[[M_2:.+]] = dim %[[ARG0]], %[[C0]] @@ -191,16 +186,11 @@ module { // CHECK: %[[N:.+]] = dim %[[ARG3]], %[[C1]] // CHECK: %[[SV2:.+]] = subview %[[ARG3]][%[[IV0]], 0] // CHECK-SAME: [%[[TILE_M_2]], %[[N]]] -// CHECK: %[[SV2_2:.+]] = subview %[[ARG3]][%[[IV0]], 0] -// CHECK-SAME: [%[[TILE_M]], %[[N]]] -// CHECK: %[[K_2:.+]] = dim %[[ARG0]], %[[C1]] // CHECK: %[[SV3:.+]] = subview %[[ARG0]][%[[IV0]], 0] -// CHECK-SAME: [%[[TILE_M]], %[[K_2]]] -// CHECK: %[[SV3_2:.+]] = subview %[[ARG1]][%[[IV0]], 0] -// CHECK-SAME: [%[[TILE_M]], %[[K_2]]] -// CHECK: linalg.copy(%[[SV3]], %[[SV3_2]]) +// CHECK-SAME: [%[[TILE_M]], %[[K]]] +// CHECK: linalg.copy(%[[SV3]], %[[SV1]]) // CHECK-SAME: __internal_linalg_transform__ = "after_two_operand_fusion_producer" -// CHECK: linalg.fill(%[[SV2_2]], %[[CST]]) +// CHECK: linalg.fill(%[[SV2]], %[[CST]]) // CHECK-SAME: __internal_linalg_transform__ = "after_two_operand_fusion_producer" // CHECK-DAG: %[[N_2:.+]] = dim %[[ARG2]], %[[C1]] // CHECK: scf.parallel (%[[IV1:.+]]) = @@ -271,18 +261,15 @@ module { // CHECK: %[[N:.+]] = dim %[[ARG4]], %[[C1]] // CHECK: %[[SV2:.+]] = subview %[[ARG4]][%[[IV0]], 0] // CHECK-SAME: [%[[TILE_M_2]], %[[N]]] -// CHECK: %[[K2_2:.+]] = dim %[[ARG1]], %[[C1]] // CHECK: %[[K1:.+]] = dim %[[ARG0]], %[[C1]] // CHECK: %[[SV3:.+]] = subview %[[ARG0]][%[[IV0]], 0] // CHECK-SAME: [%[[TILE_M]], %[[K1]]] -// CHECK: %[[SV4:.+]] = subview %[[ARG1]][0, 0] [%[[K1]], %[[K2_2]]] -// CHECK: %[[SV1_2:.+]] = subview %[[ARG2]][%[[IV0]], 0] -// CHECK-SAME: [%[[TILE_M]], %[[K2_2]]] +// CHECK: %[[SV4:.+]] = subview %[[ARG1]][0, 0] [%[[K1]], %[[K2]]] // CHECK: linalg.matmul // CHECK-SAME: __internal_linalg_transform__ = "after_lhs_fusion_producer" // CHECK-SAME: ins(%[[SV3]], %[[SV4]] // CHECK-SAME: : memref, memref) -// CHECK-SAME: outs(%[[SV1_2]] : memref) +// CHECK-SAME: outs(%[[SV1]] : memref) // CHECK-DAG: %[[N_2:.+]] = dim %[[ARG3]], %[[C1]] // CHECK: scf.parallel (%[[IV1:.+]]) = // CHECK-SAME: (%[[C0]]) to (%[[N_2]]) step (%[[C64]]) { diff --git a/mlir/test/Dialect/Linalg/fusion-sequence.mlir b/mlir/test/Dialect/Linalg/fusion-sequence.mlir deleted file mode 100644 index a02c878ef341..000000000000 --- a/mlir/test/Dialect/Linalg/fusion-sequence.mlir +++ /dev/null @@ -1,133 +0,0 @@ -// RUN: mlir-opt -pass-pipeline="func(test-linalg-tile-and-fuse{tile-sizes=16,32,64}),canonicalize,cse" -split-input-file %s | FileCheck %s - -module { - func @three_op_fusion(%arg0: memref, %arg1: memref, - %arg2: memref, %arg3 : memref) { - %cst = constant 0.000000e+00 : f32 - %c0 = constant 0 : index - %c1 = constant 1 : index - %d0 = dim %arg0, %c0 : memref - %d1 = dim %arg1, %c1 : memref - %0 = alloc(%d0, %d1) : memref - linalg.fill(%0, %cst) : memref, f32 - linalg.matmul ins(%arg0, %arg1 : memref, memref) - outs(%0 : memref) - linalg.generic - {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, - affine_map<(d0, d1) -> (d1)>, - affine_map<(d0, d1) -> (d0, d1)>], - iterator_types = ["parallel", "parallel"]} - ins(%0, %arg2 : memref, memref) - outs(%arg3 : memref) { - ^bb0(%arg4 : f32, %arg5 : f32, %arg6 : f32) : - %5 = addf %arg4, %arg5 : f32 - linalg.yield %5 : f32 - } - return - } -} - -// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)> -// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0)[s0] -> (d0 + s0)> -// CHECK: func @three_op_fusion -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: memref -// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: memref -// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: memref -// CHECK: %[[TEMP:.+]] = alloc(%{{.*}}, %{{.*}}) : memref -// CHECK: scf.parallel (%[[IV0:.+]], %[[IV1:.+]]) = {{.*}} { -// CHECK-DAG: %[[SV_TEMP:.+]] = subview %[[TEMP]][%[[IV0]], %[[IV1]]] -// CHECK-DAG: %[[SV_ARG2:.+]] = subview %[[ARG2]][%[[IV1]]] -// CHECK-DAG: %[[SV_ARG3:.+]] = subview %[[ARG3]][%[[IV0]], %[[IV1]]] -// CHECK-DAG: %[[SV_ARG0:.+]] = subview %[[ARG0]][%[[IV0]], 0] -// CHECK-DAG: %[[SV_ARG1:.+]] = subview %[[ARG1]][0, %[[IV1]]] -// CHECK: linalg.fill(%[[SV_TEMP]], %{{.+}}) -// CHECK: linalg.matmul -// CHECK-SAME: ins(%[[SV_ARG0]], %[[SV_ARG1]] -// CHECK-SAME: : memref, memref) -// CHECK-SAME: outs(%[[SV_TEMP]] : memref) -// CHECK: linalg.generic -// CHECK-SAME: ins(%[[SV_TEMP]], %[[SV_ARG2]] -// CHECK-SAME: : memref, memref) -// CHECK-SAME: outs(%[[SV_ARG3]] : memref) -// CHECK: scf.yield -// CHECK: } - -// ----- - -module { - func @sequence_of_matmul(%arg0: memref, %arg1: memref, - %arg2: memref, %arg3: memref, - %arg4: memref) { - %cst = constant 0.000000e+00 : f32 - %c0 = constant 0 : index - %c1 = constant 1 : index - %m = dim %arg0, %c0 : memref - %n1 = dim %arg1, %c1 : memref - %n2 = dim %arg2, %c1 : memref - %n3 = dim %arg3, %c1 : memref - %0 = alloc(%m, %n1) : memref - %1 = alloc(%m, %n2) : memref - linalg.fill(%0, %cst) : memref, f32 - linalg.matmul ins(%arg0, %arg1 : memref, memref) - outs(%0 : memref) - linalg.fill(%1, %cst) : memref, f32 - linalg.matmul ins(%0, %arg2 : memref, memref) - outs(%1 : memref) - linalg.fill(%arg4, %cst) : memref, f32 - linalg.matmul ins(%1, %arg3 : memref, memref) - outs(%arg4 : memref) - return - } -} - -// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (16, -d0 + s0)> -// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)> -// CHECK: func @sequence_of_matmul -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: memref -// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: memref -// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: memref -// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: memref -// CHECK-DAG: %[[C0:.+]] = constant 0 : index -// CHECK-DAG: %[[C1:.+]] = constant 1 : index -// CHECK-DAG: %[[C16:.+]] = constant 16 : index -// CHECK-DAG: %[[M:.+]] = dim %[[ARG0]], %[[C0]] -// CHECK-DAG: %[[N1:.+]] = dim %[[ARG1]], %[[C1]] -// CHECK-DAG: %[[N2:.+]] = dim %[[ARG2]], %[[C1]] -// CHECK: %[[ALLOC1:.+]] = alloc(%[[M]], %[[N1]]) -// CHECK: %[[ALLOC2:.+]] = alloc(%[[M]], %[[N2]]) -// CHECK: scf.parallel (%[[IV0:.+]]) = (%[[C0]]) to (%[[M]]) -// CHECK-SAME: step (%[[C16]]) { -// CHECK: %[[TILE_M:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[M]]] -// CHECK: %[[SV_ALLOC2:.+]] = subview %[[ALLOC2]][%[[IV0]], 0] -// CHECK-SAME: [%[[TILE_M]], %[[N2]]] -// CHECK: %[[M_2:.+]] = dim %[[ARG4]], %[[C0]] -// CHECK: %[[TILE_M_2:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[M_2]]] -// CHECK: %[[N3:.+]] = dim %[[ARG4]], %[[C1]] -// CHECK: %[[SV_ARG4:.+]] = subview %[[ARG4]][%[[IV0]], 0] -// CHECK-SAME: [%[[TILE_M_2]], %[[N3]]] -// CHECK: %[[SV_ARG4_2:.+]] = subview %[[ARG4]][%[[IV0]], 0] -// CHECK-SAME: [%[[TILE_M]], %[[N3]]] -// CHECK: %[[SV_ALLOC1:.+]] = subview %[[ALLOC1]][%[[IV0]], 0] -// CHECK-SAME: [%[[TILE_M]], %[[N1]]] -// CHECK: %[[SV_ARG2:.+]] = subview %[[ARG2]][0, 0] [%[[N1]], %[[N2]]] -// CHECK: %[[N0:.+]] = dim %[[ARG0]], %[[C1]] -// CHECK: %[[SV_ARG0:.+]] = subview %[[ARG0]][%[[IV0]], 0] -// CHECK-SAME: [%[[TILE_M:.+]], %[[N0]]] -// CHECK: %[[SV_ARG1:.+]] = subview %[[ARG1]][0, 0] [%[[N0]], %[[N1]]] -// CHECK: linalg.fill(%[[SV_ALLOC1]], %{{.+}}) -// CHECK: linalg.matmul ins(%[[SV_ARG0]], %[[SV_ARG1]] -// CHECK-SAME: : memref, memref) -// CHECK-SAME: outs(%[[SV_ALLOC1]] : memref) -// CHECK: linalg.fill(%[[SV_ALLOC2]], %{{.+}}) -// CHECK: linalg.matmul ins(%[[SV_ALLOC1]], %[[SV_ARG2]] -// CHECK-SAME: : memref, memref) -// CHECK-SAME: outs(%[[SV_ALLOC2]] : memref) -// CHECK: linalg.fill(%[[SV_ARG4_2]], %{{.+}}) -// CHECK: linalg.matmul ins(%[[SV_ALLOC2]], %[[ARG3]] -// CHECK-SAME: : memref, memref) -// CHECK-SAME: outs(%[[SV_ARG4]] : memref) -// CHECK: scf.yield -// CHECK: } - diff --git a/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp index 5289b2d1055f..eb9e3a533138 100644 --- a/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp +++ b/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp @@ -197,44 +197,6 @@ struct TestLinalgGreedyFusion } } }; - -/// Pass to test tile and fuse of sequence of operations. Intended only for -/// testing. -struct TestLinalgTileAndFuseSequencePass - : public PassWrapper { - TestLinalgTileAndFuseSequencePass() = default; - TestLinalgTileAndFuseSequencePass( - const TestLinalgTileAndFuseSequencePass &pass){}; - - ListOption tileSizes{ - *this, "tile-sizes", llvm::cl::desc("Tile sizes to use for ops"), - llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated}; - - void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); - } - - void runOnFunction() override { - FuncOp funcOp = getOperation(); - auto &blocks = funcOp.getBody().getBlocks(); - if (!llvm::hasSingleElement(blocks)) { - return; - } - SmallVector linalgOps = - llvm::to_vector<2>(blocks.front().getOps()); - Aliases aliases; - LinalgDependenceGraph dependenceGraph(aliases, linalgOps); - OpBuilder builder(funcOp.getContext()); - Optional tileAndFuseOps = tileAndFuseLinalgOps( - builder, linalgOps, dependenceGraph, - LinalgTilingOptions().setTileSizes(tileSizes).setLoopType( - LinalgTilingLoopType::ParallelLoops)); - if (!tileAndFuseOps) - return signalPassFailure(); - for (auto op : linalgOps) - op.erase(); - } -}; } // namespace namespace mlir { @@ -249,12 +211,5 @@ void registerTestLinalgGreedyFusion() { "test-linalg-greedy-fusion", "Test Linalg fusion by applying a greedy test transformation."); } -void registerTestLinalgTileAndFuseSequencePass() { - PassRegistration - testTileAndFuseSequencePass( - "test-linalg-tile-and-fuse", - "Test Linalg tiling and fusion of a sequence of Linalg operations."); -} - } // namespace test } // namespace mlir diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp index a0e36cf82534..4771b11b20e4 100644 --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -74,7 +74,6 @@ void registerTestLinalgCodegenStrategy(); void registerTestLinalgFusionTransforms(); void registerTestLinalgGreedyFusion(); void registerTestLinalgHoisting(); -void registerTestLinalgTileAndFuseSequencePass(); void registerTestLinalgTransforms(); void registerTestLivenessPass(); void registerTestLoopFusion(); @@ -142,7 +141,6 @@ void registerTestPasses() { test::registerTestLinalgFusionTransforms(); test::registerTestLinalgGreedyFusion(); test::registerTestLinalgHoisting(); - test::registerTestLinalgTileAndFuseSequencePass(); test::registerTestLinalgTransforms(); test::registerTestLivenessPass(); test::registerTestLoopFusion();