[mlir][sparse] split post-sparsification-rewriting into two passes. (#70727)
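The old post-sparsification-rewrite pass is split into two passes:

- lower-sparse-ops-to-foreach (createLowerSparseOpsToForeachPass) keeps the
  rewrites that turn high-level sparse operations (concatenate, convert,
  reshape, ...) into sparse_tensor.foreach; its enable-foreach option is gone.
- lower-sparse-foreach-to-scf (createLowerForeachToSCFPass) is a new
  func::FuncOp pass that lowers sparse_tensor.foreach to scf loops.

For pipelines built in C++ the migration is mechanical. A minimal sketch that
mirrors the in-tree pipeline change below; the option values are placeholders,
only the create* entry points come from this patch:

  // Before: one greedy pass both introduced foreach loops and lowered them.
  //   pm.addPass(createPostSparsificationRewritePass(enableRT));
  // After: schedule the two passes explicitly; the foreach lowering now runs
  // per function.
  pm.addPass(createLowerSparseOpsToForeachPass(/*enableRT=*/true,
                                               /*enableConvert=*/true));
  pm.addNestedPass<func::FuncOp>(createLowerForeachToSCFPass());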
@@ -114,17 +114,23 @@ void populateStageSparseOperationsPatterns(RewritePatternSet &patterns);
 std::unique_ptr<Pass> createStageSparseOperationsPass();
 
 //===----------------------------------------------------------------------===//
-// The PostSparsificationRewriting pass.
+// The LowerSparseOpsToForeach pass.
 //===----------------------------------------------------------------------===//
 
-void populatePostSparsificationRewriting(RewritePatternSet &patterns,
-                                         bool enableRT, bool enableForeach,
-                                         bool enableConvert);
+void populateLowerSparseOpsToForeachPatterns(RewritePatternSet &patterns,
+                                             bool enableRT, bool enableConvert);
 
-std::unique_ptr<Pass> createPostSparsificationRewritePass();
-std::unique_ptr<Pass>
-createPostSparsificationRewritePass(bool enableRT, bool enableForeach = true,
-                                    bool enableConvert = true);
+std::unique_ptr<Pass> createLowerSparseOpsToForeachPass();
+std::unique_ptr<Pass> createLowerSparseOpsToForeachPass(bool enableRT,
+                                                        bool enableConvert);
+
+//===----------------------------------------------------------------------===//
+// The LowerForeachToSCF pass.
+//===----------------------------------------------------------------------===//
+
+void populateLowerForeachToSCFPatterns(RewritePatternSet &patterns);
+
+std::unique_ptr<Pass> createLowerForeachToSCFPass();
 
 //===----------------------------------------------------------------------===//
 // The SparseTensorConversion pass.
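Downstream code that called populatePostSparsificationRewriting() directly now
picks from the two populate entry points above. A hedged sketch of a custom
pass body under that assumption; the surrounding pass is hypothetical, only
the populate*/applyPatternsAndFoldGreedily calls follow this patch:

  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    // High-level sparse ops (concatenate, convert, reshape, ...) -> foreach.
    populateLowerSparseOpsToForeachPatterns(patterns, /*enableRT=*/false,
                                            /*enableConvert=*/true);
    // sparse_tensor.foreach -> scf loops.
    populateLowerForeachToSCFPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }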
@@ -167,13 +167,12 @@ def StageSparseOperations : Pass<"stage-sparse-ops", "func::FuncOp"> {
   ];
 }
 
-def PostSparsificationRewrite : Pass<"post-sparsification-rewrite", "ModuleOp"> {
+def LowerSparseOpsToForeach : Pass<"lower-sparse-ops-to-foreach", "ModuleOp"> {
   let summary = "Applies sparse tensor rewriting rules after sparsification";
   let description = [{
-    A pass that applies rewriting rules to sparse tensor operations after
-    running the actual sparsification pass.
+    A pass that lowers high-level sparse operations to sparse_tensor.foreach.
   }];
-  let constructor = "mlir::createPostSparsificationRewritePass()";
+  let constructor = "mlir::createLowerSparseOpsToForeachPass()";
   let dependentDialects = [
     "affine::AffineDialect",
     "arith::ArithDialect",
@@ -186,13 +185,25 @@ def PostSparsificationRewrite : Pass<"post-sparsification-rewrite", "ModuleOp">
   let options = [
     Option<"enableRuntimeLibrary", "enable-runtime-library", "bool",
            "true", "Enable runtime library for manipulating sparse tensors">,
-    Option<"enableForeach", "enable-foreach", "bool",
-           "true", "Enable rewriting rules for the foreach operator">,
     Option<"enableConvert", "enable-convert", "bool",
            "true", "Enable rewriting rules for the convert operator">,
   ];
 }
 
+def LowerForeachToSCF : Pass<"lower-sparse-foreach-to-scf", "func::FuncOp"> {
+  let summary = "Decompose a complex sparse operation into multiple stages";
+  let description = [{
+    A pass that lowers sparse_tensor.foreach operation to scf dialect.
+  }];
+  let constructor = "mlir::createLowerForeachToSCFPass()";
+  let dependentDialects = [
+    "memref::MemRefDialect",
+    "scf::SCFDialect",
+    "sparse_tensor::SparseTensorDialect",
+  ];
+}
+
 def SparseTensorConversionPass : Pass<"sparse-tensor-conversion", "ModuleOp"> {
   let summary = "Convert sparse tensors and primitives to library calls";
   let description = [{
@@ -25,7 +25,8 @@ namespace mlir {
 #define GEN_PASS_DEF_SPARSEREINTERPRETMAP
 #define GEN_PASS_DEF_PRESPARSIFICATIONREWRITE
 #define GEN_PASS_DEF_SPARSIFICATIONPASS
-#define GEN_PASS_DEF_POSTSPARSIFICATIONREWRITE
+#define GEN_PASS_DEF_LOWERSPARSEOPSTOFOREACH
+#define GEN_PASS_DEF_LOWERFOREACHTOSCF
 #define GEN_PASS_DEF_SPARSETENSORCONVERSIONPASS
 #define GEN_PASS_DEF_SPARSETENSORCODEGEN
 #define GEN_PASS_DEF_SPARSEBUFFERREWRITE
@@ -120,23 +121,34 @@ struct StageSparseOperationsPass
   }
 };
 
-struct PostSparsificationRewritePass
-    : public impl::PostSparsificationRewriteBase<
-          PostSparsificationRewritePass> {
-  PostSparsificationRewritePass() = default;
-  PostSparsificationRewritePass(const PostSparsificationRewritePass &pass) =
+struct LowerSparseOpsToForeachPass
+    : public impl::LowerSparseOpsToForeachBase<LowerSparseOpsToForeachPass> {
+  LowerSparseOpsToForeachPass() = default;
+  LowerSparseOpsToForeachPass(const LowerSparseOpsToForeachPass &pass) =
       default;
-  PostSparsificationRewritePass(bool enableRT, bool foreach, bool convert) {
+  LowerSparseOpsToForeachPass(bool enableRT, bool convert) {
     enableRuntimeLibrary = enableRT;
-    enableForeach = foreach;
     enableConvert = convert;
   }
 
   void runOnOperation() override {
     auto *ctx = &getContext();
     RewritePatternSet patterns(ctx);
-    populatePostSparsificationRewriting(patterns, enableRuntimeLibrary,
-                                        enableForeach, enableConvert);
+    populateLowerSparseOpsToForeachPatterns(patterns, enableRuntimeLibrary,
+                                            enableConvert);
     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
   }
 };
+
+struct LowerForeachToSCFPass
+    : public impl::LowerForeachToSCFBase<LowerForeachToSCFPass> {
+  LowerForeachToSCFPass() = default;
+  LowerForeachToSCFPass(const LowerForeachToSCFPass &pass) = default;
+
+  void runOnOperation() override {
+    auto *ctx = &getContext();
+    RewritePatternSet patterns(ctx);
+    populateLowerForeachToSCFPatterns(patterns);
+    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+  }
+};
@@ -399,15 +411,17 @@ std::unique_ptr<Pass> mlir::createStageSparseOperationsPass() {
   return std::make_unique<StageSparseOperationsPass>();
 }
 
-std::unique_ptr<Pass> mlir::createPostSparsificationRewritePass() {
-  return std::make_unique<PostSparsificationRewritePass>();
+std::unique_ptr<Pass> mlir::createLowerSparseOpsToForeachPass() {
+  return std::make_unique<LowerSparseOpsToForeachPass>();
 }
 
 std::unique_ptr<Pass>
-mlir::createPostSparsificationRewritePass(bool enableRT, bool enableForeach,
-                                          bool enableConvert) {
-  return std::make_unique<PostSparsificationRewritePass>(
-      enableRT, enableForeach, enableConvert);
+mlir::createLowerSparseOpsToForeachPass(bool enableRT, bool enableConvert) {
+  return std::make_unique<LowerSparseOpsToForeachPass>(enableRT, enableConvert);
 }
 
+std::unique_ptr<Pass> mlir::createLowerForeachToSCFPass() {
+  return std::make_unique<LowerForeachToSCFPass>();
+}
+
 std::unique_ptr<Pass> mlir::createSparseTensorConversionPass() {
@@ -1303,10 +1303,9 @@ void mlir::populatePreSparsificationRewriting(RewritePatternSet &patterns) {
                GenSemiRingReduction, GenSemiRingSelect>(patterns.getContext());
 }
 
-void mlir::populatePostSparsificationRewriting(RewritePatternSet &patterns,
-                                               bool enableRT,
-                                               bool enableForeach,
-                                               bool enableConvert) {
+void mlir::populateLowerSparseOpsToForeachPatterns(RewritePatternSet &patterns,
+                                                   bool enableRT,
+                                                   bool enableConvert) {
   patterns.add<ConcatenateRewriter, CrdTranslateRewriter,
                ReshapeRewriter<tensor::ExpandShapeOp>,
                ReshapeRewriter<tensor::CollapseShapeOp>,
@@ -1314,10 +1313,13 @@ void mlir::populatePostSparsificationRewriting(RewritePatternSet &patterns,
                Sparse2SparseReshapeRewriter<tensor::CollapseShapeOp>,
                SparseTensorDimOpRewriter, TensorReshapeRewriter, OutRewriter>(
       patterns.getContext());
-  if (enableForeach)
-    patterns.add<ForeachRewriter>(patterns.getContext());
-
   if (enableConvert)
     patterns.add<DirectConvertRewriter>(patterns.getContext());
   if (!enableRT)
     patterns.add<NewRewriter>(patterns.getContext());
 }
+
+void mlir::populateLowerForeachToSCFPatterns(RewritePatternSet &patterns) {
+  patterns.add<ForeachRewriter>(patterns.getContext());
+}
@@ -141,7 +141,10 @@ public:
     OpPassManager pm("builtin.module");
     pm.addPass(createSparsificationPass(sparsificationOptions));
     pm.addNestedPass<func::FuncOp>(createStageSparseOperationsPass());
-    pm.addPass(createPostSparsificationRewritePass(enableRuntimeLibrary));
+    pm.addPass(createLowerSparseOpsToForeachPass(enableRuntimeLibrary,
+                                                 /*enableConvert=*/true));
+    // TODO: DemapPass here!
+    pm.addNestedPass<func::FuncOp>(createLowerForeachToSCFPass());
     if (vectorLength > 0) {
       pm.addPass(mlir::createLoopInvariantCodeMotionPass());
       pm.addPass(createSparseVectorizationPass(
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --post-sparsification-rewrite --sparse-tensor-codegen --canonicalize -cse | FileCheck %s
+// RUN: mlir-opt %s --lower-sparse-ops-to-foreach --lower-sparse-foreach-to-scf --sparse-tensor-codegen --canonicalize -cse | FileCheck %s
 
 #SV = #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>
 
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --post-sparsification-rewrite --sparse-tensor-conversion --canonicalize --cse | FileCheck %s
+// RUN: mlir-opt %s --lower-sparse-ops-to-foreach --lower-sparse-foreach-to-scf --sparse-tensor-conversion --canonicalize --cse | FileCheck %s
 
 #SparseVector = #sparse_tensor.encoding<{
   map = (d0) -> (d0 : compressed)
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --stage-sparse-ops --post-sparsification-rewrite="enable-foreach=false" --canonicalize --cse | FileCheck %s
+// RUN: mlir-opt %s --stage-sparse-ops --lower-sparse-ops-to-foreach --canonicalize --cse | FileCheck %s
 
 #SparseVector = #sparse_tensor.encoding<{
   map = (d0) -> (d0 : compressed)
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --stage-sparse-ops --post-sparsification-rewrite="enable-foreach=false" --canonicalize --cse | FileCheck %s
+// RUN: mlir-opt %s --stage-sparse-ops --lower-sparse-ops-to-foreach --canonicalize --cse | FileCheck %s
 
 #SparseVector = #sparse_tensor.encoding<{
   map = (d0) -> (d0 : compressed)
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --stage-sparse-ops --post-sparsification-rewrite="enable-foreach=false" --canonicalize --cse | FileCheck %s
+// RUN: mlir-opt %s --stage-sparse-ops --lower-sparse-ops-to-foreach --canonicalize --cse | FileCheck %s
 
 #SparseVector64 = #sparse_tensor.encoding<{
   map = (d0) -> (d0 : compressed),
@@ -1,5 +1,5 @@
-// RUN: mlir-opt %s -post-sparsification-rewrite="enable-runtime-library=false enable-convert=false" | \
-// RUN: FileCheck %s
+// RUN: mlir-opt %s --lower-sparse-ops-to-foreach="enable-runtime-library=false enable-convert=false" \
+// RUN: --lower-sparse-foreach-to-scf | FileCheck %s
 
 #CSR = #sparse_tensor.encoding<{
   map = (d0, d1) -> (d0 : dense, d1 : compressed)
@@ -1,6 +1,6 @@
-// RUN: mlir-opt %s --post-sparsification-rewrite="enable-runtime-library=false enable-convert=false" \
+// RUN: mlir-opt %s --lower-sparse-ops-to-foreach="enable-runtime-library=false enable-convert=false" --lower-sparse-foreach-to-scf \
 // RUN: | FileCheck %s
-// RUN: mlir-opt %s --post-sparsification-rewrite="enable-runtime-library=true enable-convert=false" \
+// RUN: mlir-opt %s --lower-sparse-ops-to-foreach="enable-runtime-library=true enable-convert=false" --lower-sparse-foreach-to-scf \
 // RUN: | FileCheck %s
 
 
@@ -4,7 +4,8 @@
 // RUN: FileCheck %s --check-prefix=CHECK-SPARSE
 // RUN: mlir-opt %s --linalg-generalize-named-ops \
 // RUN: --linalg-fuse-elementwise-ops \
-// RUN: --sparsification --post-sparsification-rewrite \
+// RUN: --sparsification --lower-sparse-ops-to-foreach \
+// RUN: --lower-sparse-foreach-to-scf \
 // RUN: --sparse-tensor-conversion --cse | \
 // RUN: FileCheck %s --check-prefix=CHECK-CONVERT
 
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --post-sparsification-rewrite="enable-runtime-library=false enable-foreach=true" --canonicalize | FileCheck %s
+// RUN: mlir-opt %s --lower-sparse-foreach-to-scf --canonicalize | FileCheck %s
 
 // CHECK-LABEL: func.func @sparse_foreach_constant
 // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --canonicalize --post-sparsification-rewrite="enable-runtime-library=false" --sparse-tensor-codegen -cse --canonicalize | FileCheck %s
+// RUN: mlir-opt %s --canonicalize --sparse-tensor-codegen -cse --canonicalize | FileCheck %s
 
 #COO = #sparse_tensor.encoding<{
   map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton),
@@ -1,8 +1,8 @@
 // RUN: mlir-opt %s | mlir-opt | FileCheck %s --check-prefix=CHECK-ROUND
-// RUN: mlir-opt %s --post-sparsification-rewrite="enable-runtime-library=true enable-convert=false" \
-// RUN: --cse --canonicalize | FileCheck %s
-// RUN: mlir-opt %s --post-sparsification-rewrite="enable-runtime-library=false enable-convert=false" \
-// RUN: --cse --canonicalize | FileCheck %s
+// RUN: mlir-opt %s --lower-sparse-ops-to-foreach="enable-runtime-library=true enable-convert=false" \
+// RUN: --lower-sparse-foreach-to-scf --cse --canonicalize | FileCheck %s
+// RUN: mlir-opt %s --lower-sparse-ops-to-foreach="enable-runtime-library=false enable-convert=false" \
+// RUN: --lower-sparse-foreach-to-scf --cse --canonicalize | FileCheck %s
 
 #SparseVector = #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>
 #SparseMatrix = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed, d1 : compressed) }>
@@ -1,5 +1,5 @@
-// RUN: mlir-opt %s --post-sparsification-rewrite="enable-runtime-library=false enable-convert=false" \
-// RUN: --cse --canonicalize | FileCheck %s
+// RUN: mlir-opt %s --lower-sparse-ops-to-foreach="enable-runtime-library=false enable-convert=false" \
+// RUN: --lower-sparse-foreach-to-scf --cse --canonicalize | FileCheck %s
 
 #SparseMatrix = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed, d1 : compressed) }>
 