mirror of
https://github.com/intel/llvm.git
synced 2026-02-01 17:07:36 +08:00
[MLIR][Linalg] Lower linalg.tiled_loop in a separate pass
Add dedicated pass `convert-linalg-tiled-loops-to-scf` to lower `linalg.tiled_loop`s. Differential Revision: https://reviews.llvm.org/D101768
This commit is contained in:
@@ -36,6 +36,10 @@ std::unique_ptr<OperationPass<FuncOp>>
|
||||
createLinalgPromotionPass(bool dynamicBuffers, bool useAlloca);
|
||||
std::unique_ptr<OperationPass<FuncOp>> createLinalgPromotionPass();
|
||||
|
||||
/// Create a pass to convert Linalg tiled loops to `scf.for` and `scf.parallel`
|
||||
/// loops and memref.load/memref.store accesses.
|
||||
std::unique_ptr<OperationPass<FuncOp>> createConvertLinalgTiledLoopsToSCFPass();
|
||||
|
||||
/// Create a pass to convert Linalg operations to scf.for loops and
|
||||
/// memref.load/memref.store accesses.
|
||||
std::unique_ptr<OperationPass<FuncOp>> createConvertLinalgToLoopsPass();
|
||||
|
||||
@@ -58,6 +58,17 @@ def LinalgFoldReshapeOpsByLinearization :
|
||||
let dependentDialects = ["AffineDialect", "memref::MemRefDialect"];
|
||||
}
|
||||
|
||||
def LinalgLowerTiledLoopsToSCF
|
||||
: FunctionPass<"convert-linalg-tiled-loops-to-scf"> {
|
||||
let summary = "Lower linalg tiled loops to SCF loops and parallel loops";
|
||||
let constructor = "mlir::createConvertLinalgTiledLoopsToSCFPass()";
|
||||
let dependentDialects = [
|
||||
"linalg::LinalgDialect",
|
||||
"scf::SCFDialect",
|
||||
"AffineDialect"
|
||||
];
|
||||
}
|
||||
|
||||
def LinalgLowerToAffineLoops : FunctionPass<"convert-linalg-to-affine-loops"> {
|
||||
let summary = "Lower the operations from the linalg dialect into affine "
|
||||
"loops";
|
||||
@@ -76,16 +87,6 @@ def LinalgLowerToLoops : FunctionPass<"convert-linalg-to-loops"> {
|
||||
];
|
||||
}
|
||||
|
||||
def LinalgBufferize : Pass<"linalg-bufferize", "FuncOp"> {
|
||||
let summary = "Bufferize the linalg dialect";
|
||||
let constructor = "mlir::createLinalgBufferizePass()";
|
||||
let dependentDialects = [
|
||||
"linalg::LinalgDialect",
|
||||
"AffineDialect",
|
||||
"memref::MemRefDialect"
|
||||
];
|
||||
}
|
||||
|
||||
def LinalgLowerToParallelLoops
|
||||
: FunctionPass<"convert-linalg-to-parallel-loops"> {
|
||||
let summary = "Lower the operations from the linalg dialect into parallel "
|
||||
@@ -99,6 +100,16 @@ def LinalgLowerToParallelLoops
|
||||
];
|
||||
}
|
||||
|
||||
def LinalgBufferize : Pass<"linalg-bufferize", "FuncOp"> {
|
||||
let summary = "Bufferize the linalg dialect";
|
||||
let constructor = "mlir::createLinalgBufferizePass()";
|
||||
let dependentDialects = [
|
||||
"linalg::LinalgDialect",
|
||||
"AffineDialect",
|
||||
"memref::MemRefDialect"
|
||||
];
|
||||
}
|
||||
|
||||
def LinalgPromotion : FunctionPass<"linalg-promote-subviews"> {
|
||||
let summary = "Promote subview ops to local buffers";
|
||||
let constructor = "mlir::createLinalgPromotionPass()";
|
||||
|
||||
@@ -555,7 +555,7 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
struct TiledLoopPattern : public OpRewritePattern<TiledLoopOp> {
|
||||
struct TiledLoopToSCFPattern : public OpRewritePattern<TiledLoopOp> {
|
||||
using OpRewritePattern<TiledLoopOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(TiledLoopOp tiledLoop,
|
||||
@@ -597,7 +597,7 @@ template <typename LoopType>
|
||||
static void lowerLinalgToLoopsImpl(FuncOp funcOp) {
|
||||
MLIRContext *context = funcOp.getContext();
|
||||
RewritePatternSet patterns(context);
|
||||
patterns.add<LinalgRewritePattern<LoopType>, TiledLoopPattern>(context);
|
||||
patterns.add<LinalgRewritePattern<LoopType>>(context);
|
||||
memref::DimOp::getCanonicalizationPatterns(patterns, context);
|
||||
AffineApplyOp::getCanonicalizationPatterns(patterns, context);
|
||||
patterns.add<FoldAffineOp>(context);
|
||||
@@ -668,8 +668,23 @@ struct LowerToParallelLoops
|
||||
lowerLinalgToLoopsImpl<scf::ParallelOp>(getFunction());
|
||||
}
|
||||
};
|
||||
|
||||
struct LowerTiledLoopsToSCF
|
||||
: public LinalgLowerTiledLoopsToSCFBase<LowerTiledLoopsToSCF> {
|
||||
void runOnFunction() override {
|
||||
MLIRContext *context = &getContext();
|
||||
RewritePatternSet patterns(context);
|
||||
patterns.add<TiledLoopToSCFPattern>(context);
|
||||
(void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<OperationPass<FuncOp>>
|
||||
mlir::createConvertLinalgTiledLoopsToSCFPass() {
|
||||
return std::make_unique<LowerTiledLoopsToSCF>();
|
||||
}
|
||||
|
||||
std::unique_ptr<OperationPass<FuncOp>> mlir::createConvertLinalgToLoopsPass() {
|
||||
return std::make_unique<LowerToLoops>();
|
||||
}
|
||||
|
||||
@@ -1522,78 +1522,3 @@ func @conv3d_no_symbols(%in : memref<?x?x?xf32>, %filter : memref<?x?x?xf32>, %o
|
||||
// CHECKPARALLEL: %[[inc:.*]] = mulf %[[vb]], %[[va]] : f32
|
||||
// CHECKPARALLEL: %[[res:.*]] = addf %[[vc]], %[[inc]] : f32
|
||||
// CHECKPARALLEL: store %[[res]], %[[arg2]][%[[arg3]], %[[arg4]], %[[arg5]]] : memref<?x?x?xf32>
|
||||
|
||||
|
||||
#map0 = affine_map<(d0) -> (24, -d0 + 192)>
|
||||
#map1 = affine_map<(d0, d1)[s0] -> (d0 * 192 + s0 + d1)>
|
||||
#map2 = affine_map<(d0) -> (16, -d0 + 192)>
|
||||
|
||||
func @tiled_loop_to_parallel(%A: memref<192x192xf32>,
|
||||
%B: memref<192x192xf32>,
|
||||
%C: memref<192x192xf32>) {
|
||||
%cst = constant 0.000000e+00 : f32
|
||||
%c24 = constant 24 : index
|
||||
%c16 = constant 16 : index
|
||||
%c0 = constant 0 : index
|
||||
%c192 = constant 192 : index
|
||||
|
||||
linalg.tiled_loop (%i, %j) = (%c0, %c0) to (%c192, %c192) step (%c24, %c16)
|
||||
ins (%A_ = %A: memref<192x192xf32>, %B_ = %B: memref<192x192xf32>)
|
||||
outs (%C_ = %C: memref<192x192xf32>) {
|
||||
%0 = affine.min #map0(%i)
|
||||
%1 = memref.subview %A_[%i, 0] [%0, 192] [1, 1]
|
||||
: memref<192x192xf32> to memref<?x192xf32, #map1>
|
||||
%2 = affine.min #map2(%j)
|
||||
%3 = memref.subview %B_[0, %j] [192, %2] [1, 1]
|
||||
: memref<192x192xf32> to memref<192x?xf32, #map1>
|
||||
%4 = memref.subview %C_[%i, %j] [%0, %2] [1, 1]
|
||||
: memref<192x192xf32> to memref<?x?xf32, #map1>
|
||||
linalg.fill(%4, %cst) : memref<?x?xf32, #map1>, f32
|
||||
linalg.matmul ins(%1, %3 : memref<?x192xf32, #map1>,
|
||||
memref<192x?xf32, #map1>)
|
||||
outs(%4 : memref<?x?xf32, #map1>)
|
||||
linalg.yield
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// CHECKLOOP-LABEL: @tiled_loop_to_parallel
|
||||
// CHECKLOOP-SAME: %[[A:.*]]: memref<192x192xf32>, %[[B:.*]]: memref<192x192xf32>,
|
||||
// CHECKLOOP-SAME: %[[C:.*]]: memref<192x192xf32>) {
|
||||
// CHECKLOOP: %[[C24:.*]] = constant 24 : index
|
||||
// CHECKLOOP: %[[C16:.*]] = constant 16 : index
|
||||
// CHECKLOOP: %[[C192:.*]] = constant 192 : index
|
||||
// CHECKLOOP: %[[C0:.*]] = constant 0 : index
|
||||
// CHECKLOOP: scf.for %[[I:.*]] = %[[C0]] to %[[C192]] step %[[C24]] {
|
||||
// CHECKLOOP: scf.for %[[J:.*]] = %[[C0]] to %[[C192]] step %[[C16]] {
|
||||
// CHECKLOOP: %[[A_sub:.*]] = memref.subview %[[A]][%[[I]]
|
||||
// CHECKLOOP: %[[B_sub:.*]] = memref.subview %[[B]][0, %[[J]]]
|
||||
// CHECKLOOP: %[[C_sub:.*]] = memref.subview %[[C]][%[[I]]
|
||||
|
||||
|
||||
func @tiled_loop_to_for(%A: memref<192x192xf32>,
|
||||
%B: memref<192x192xf32>,
|
||||
%C: memref<f32>) {
|
||||
%c24 = constant 24 : index
|
||||
%c16 = constant 16 : index
|
||||
%c0 = constant 0 : index
|
||||
%c192 = constant 192 : index
|
||||
%cst = constant 0.000000e+00 : f32
|
||||
|
||||
linalg.tiled_loop (%i, %j) = (%c0, %c0) to (%c192, %c192) step (%c24, %c16)
|
||||
ins (%A_ = %A: memref<192x192xf32>, %B_ = %B: memref<192x192xf32>)
|
||||
outs (%C_ = %C: memref<f32>)
|
||||
iterators["reduction", "reduction"] {
|
||||
linalg.fill(%A_, %cst) : memref<192x192xf32>, f32
|
||||
linalg.yield
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// CHECKLOOP-LABEL: @tiled_loop_to_for
|
||||
// CHECKLOOP: %[[C24:.*]] = constant 24 : index
|
||||
// CHECKLOOP: %[[C16:.*]] = constant 16 : index
|
||||
// CHECKLOOP: %[[C192:.*]] = constant 192 : index
|
||||
// CHECKLOOP: %[[C0:.*]] = constant 0 : index
|
||||
// CHECKLOOP: scf.for %{{.*}} = %[[C0]] to %[[C192]] step %[[C24]]
|
||||
// CHECKLOOP: scf.for %{{.*}} = %[[C0]] to %[[C192]] step %[[C16]]
|
||||
|
||||
79
mlir/test/Dialect/Linalg/tiled-loops.mlir
Normal file
79
mlir/test/Dialect/Linalg/tiled-loops.mlir
Normal file
@@ -0,0 +1,79 @@
|
||||
// RUN: mlir-opt %s -convert-linalg-tiled-loops-to-scf | FileCheck %s
|
||||
|
||||
|
||||
#map0 = affine_map<(d0) -> (24, -d0 + 192)>
|
||||
#map1 = affine_map<(d0, d1)[s0] -> (d0 * 192 + s0 + d1)>
|
||||
#map2 = affine_map<(d0) -> (16, -d0 + 192)>
|
||||
|
||||
func @tiled_loop(%A: memref<192x192xf32>,
|
||||
%B: memref<192x192xf32>,
|
||||
%C: memref<192x192xf32>) {
|
||||
%cst = constant 0.000000e+00 : f32
|
||||
%c24 = constant 24 : index
|
||||
%c16 = constant 16 : index
|
||||
%c0 = constant 0 : index
|
||||
%c192 = constant 192 : index
|
||||
|
||||
linalg.tiled_loop (%i, %j) = (%c0, %c0) to (%c192, %c192) step (%c24, %c16)
|
||||
ins (%A_ = %A: memref<192x192xf32>, %B_ = %B: memref<192x192xf32>)
|
||||
outs (%C_ = %C: memref<192x192xf32>) {
|
||||
%0 = affine.min #map0(%i)
|
||||
%1 = memref.subview %A_[%i, 0] [%0, 192] [1, 1]
|
||||
: memref<192x192xf32> to memref<?x192xf32, #map1>
|
||||
%2 = affine.min #map2(%j)
|
||||
%3 = memref.subview %B_[0, %j] [192, %2] [1, 1]
|
||||
: memref<192x192xf32> to memref<192x?xf32, #map1>
|
||||
%4 = memref.subview %C_[%i, %j] [%0, %2] [1, 1]
|
||||
: memref<192x192xf32> to memref<?x?xf32, #map1>
|
||||
linalg.fill(%4, %cst) : memref<?x?xf32, #map1>, f32
|
||||
linalg.matmul ins(%1, %3 : memref<?x192xf32, #map1>,
|
||||
memref<192x?xf32, #map1>)
|
||||
outs(%4 : memref<?x?xf32, #map1>)
|
||||
linalg.yield
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @tiled_loop
|
||||
// CHECK-SAME: %[[A:.*]]: memref<192x192xf32>, %[[B:.*]]: memref<192x192xf32>,
|
||||
// CHECK-SAME: %[[C:.*]]: memref<192x192xf32>) {
|
||||
// CHECK: %[[C24:.*]] = constant 24 : index
|
||||
// CHECK: %[[C16:.*]] = constant 16 : index
|
||||
// CHECK: %[[C0:.*]] = constant 0 : index
|
||||
// CHECK: %[[C192:.*]] = constant 192 : index
|
||||
// CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[C192]] step %[[C24]] {
|
||||
// CHECK: scf.for %[[J:.*]] = %[[C0]] to %[[C192]] step %[[C16]] {
|
||||
// CHECK: %[[A_sub:.*]] = memref.subview %[[A]][%[[I]]
|
||||
// CHECK: %[[B_sub:.*]] = memref.subview %[[B]][0, %[[J]]]
|
||||
// CHECK: %[[C_sub:.*]] = memref.subview %[[C]][%[[I]]
|
||||
// CHECK: linalg.fill
|
||||
// CHECK: linalg.matmul
|
||||
|
||||
|
||||
func @tiled_loop_reduction(%A: memref<192x192xf32>,
|
||||
%B: memref<192x192xf32>,
|
||||
%C: memref<f32>) {
|
||||
%c24 = constant 24 : index
|
||||
%c16 = constant 16 : index
|
||||
%c0 = constant 0 : index
|
||||
%c192 = constant 192 : index
|
||||
%cst = constant 0.000000e+00 : f32
|
||||
|
||||
linalg.tiled_loop (%i, %j) = (%c0, %c0) to (%c192, %c192) step (%c24, %c16)
|
||||
ins (%A_ = %A: memref<192x192xf32>, %B_ = %B: memref<192x192xf32>)
|
||||
outs (%C_ = %C: memref<f32>)
|
||||
iterators["reduction", "reduction"] {
|
||||
linalg.fill(%A_, %cst) : memref<192x192xf32>, f32
|
||||
linalg.yield
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @tiled_loop_reduction
|
||||
// CHECK: %[[C24:.*]] = constant 24 : index
|
||||
// CHECK: %[[C16:.*]] = constant 16 : index
|
||||
// CHECK: %[[C0:.*]] = constant 0 : index
|
||||
// CHECK: %[[C192:.*]] = constant 192 : index
|
||||
// CHECK: scf.for %{{.*}} = %[[C0]] to %[[C192]] step %[[C24]]
|
||||
// CHECK: scf.for %{{.*}} = %[[C0]] to %[[C192]] step %[[C16]]
|
||||
// CHECK: linalg.fill
|
||||
Reference in New Issue
Block a user