[mlir][sparse] split post-sparsification-rewriting into two passes. (#70727)

This commit is contained in:
Peiming Liu
2023-10-30 15:22:21 -07:00
committed by GitHub
parent b1c59b516c
commit f82bee1367
17 changed files with 92 additions and 55 deletions

View File

@@ -114,17 +114,23 @@ void populateStageSparseOperationsPatterns(RewritePatternSet &patterns);
std::unique_ptr<Pass> createStageSparseOperationsPass();
//===----------------------------------------------------------------------===//
// The PostSparsificationRewriting pass.
// The LowerSparseOpsToForeach pass.
//===----------------------------------------------------------------------===//
void populatePostSparsificationRewriting(RewritePatternSet &patterns,
bool enableRT, bool enableForeach,
bool enableConvert);
void populateLowerSparseOpsToForeachPatterns(RewritePatternSet &patterns,
bool enableRT, bool enableConvert);
std::unique_ptr<Pass> createPostSparsificationRewritePass();
std::unique_ptr<Pass>
createPostSparsificationRewritePass(bool enableRT, bool enableForeach = true,
bool enableConvert = true);
std::unique_ptr<Pass> createLowerSparseOpsToForeachPass();
std::unique_ptr<Pass> createLowerSparseOpsToForeachPass(bool enableRT,
bool enableConvert);
//===----------------------------------------------------------------------===//
// The LowerForeachToSCF pass.
//===----------------------------------------------------------------------===//
void populateLowerForeachToSCFPatterns(RewritePatternSet &patterns);
std::unique_ptr<Pass> createLowerForeachToSCFPass();
//===----------------------------------------------------------------------===//
// The SparseTensorConversion pass.

View File

@@ -167,13 +167,12 @@ def StageSparseOperations : Pass<"stage-sparse-ops", "func::FuncOp"> {
];
}
def PostSparsificationRewrite : Pass<"post-sparsification-rewrite", "ModuleOp"> {
def LowerSparseOpsToForeach : Pass<"lower-sparse-ops-to-foreach", "ModuleOp"> {
let summary = "Applies sparse tensor rewriting rules after sparsification";
let description = [{
A pass that applies rewriting rules to sparse tensor operations after
running the actual sparsification pass.
A pass that lowers high-level sparse operations to sparse_tensor.foreach.
}];
let constructor = "mlir::createPostSparsificationRewritePass()";
let constructor = "mlir::createLowerSparseOpsToForeachPass()";
let dependentDialects = [
"affine::AffineDialect",
"arith::ArithDialect",
@@ -186,13 +185,25 @@ def PostSparsificationRewrite : Pass<"post-sparsification-rewrite", "ModuleOp">
let options = [
Option<"enableRuntimeLibrary", "enable-runtime-library", "bool",
"true", "Enable runtime library for manipulating sparse tensors">,
Option<"enableForeach", "enable-foreach", "bool",
"true", "Enable rewriting rules for the foreach operator">,
Option<"enableConvert", "enable-convert", "bool",
"true", "Enable rewriting rules for the convert operator">,
];
}
def LowerForeachToSCF : Pass<"lower-sparse-foreach-to-scf", "func::FuncOp"> {
let summary = "Decompose a complex sparse operation into multiple stages";
let description = [{
A pass that lowers sparse_tensor.foreach operation to scf dialect.
}];
let constructor = "mlir::createLowerForeachToSCFPass()";
let dependentDialects = [
"memref::MemRefDialect",
"scf::SCFDialect",
"sparse_tensor::SparseTensorDialect",
];
}
def SparseTensorConversionPass : Pass<"sparse-tensor-conversion", "ModuleOp"> {
let summary = "Convert sparse tensors and primitives to library calls";
let description = [{

View File

@@ -25,7 +25,8 @@ namespace mlir {
#define GEN_PASS_DEF_SPARSEREINTERPRETMAP
#define GEN_PASS_DEF_PRESPARSIFICATIONREWRITE
#define GEN_PASS_DEF_SPARSIFICATIONPASS
#define GEN_PASS_DEF_POSTSPARSIFICATIONREWRITE
#define GEN_PASS_DEF_LOWERSPARSEOPSTOFOREACH
#define GEN_PASS_DEF_LOWERFOREACHTOSCF
#define GEN_PASS_DEF_SPARSETENSORCONVERSIONPASS
#define GEN_PASS_DEF_SPARSETENSORCODEGEN
#define GEN_PASS_DEF_SPARSEBUFFERREWRITE
@@ -120,23 +121,34 @@ struct StageSparseOperationsPass
}
};
struct PostSparsificationRewritePass
: public impl::PostSparsificationRewriteBase<
PostSparsificationRewritePass> {
PostSparsificationRewritePass() = default;
PostSparsificationRewritePass(const PostSparsificationRewritePass &pass) =
struct LowerSparseOpsToForeachPass
: public impl::LowerSparseOpsToForeachBase<LowerSparseOpsToForeachPass> {
LowerSparseOpsToForeachPass() = default;
LowerSparseOpsToForeachPass(const LowerSparseOpsToForeachPass &pass) =
default;
PostSparsificationRewritePass(bool enableRT, bool foreach, bool convert) {
LowerSparseOpsToForeachPass(bool enableRT, bool convert) {
enableRuntimeLibrary = enableRT;
enableForeach = foreach;
enableConvert = convert;
}
void runOnOperation() override {
auto *ctx = &getContext();
RewritePatternSet patterns(ctx);
populatePostSparsificationRewriting(patterns, enableRuntimeLibrary,
enableForeach, enableConvert);
populateLowerSparseOpsToForeachPatterns(patterns, enableRuntimeLibrary,
enableConvert);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
};
struct LowerForeachToSCFPass
: public impl::LowerForeachToSCFBase<LowerForeachToSCFPass> {
LowerForeachToSCFPass() = default;
LowerForeachToSCFPass(const LowerForeachToSCFPass &pass) = default;
void runOnOperation() override {
auto *ctx = &getContext();
RewritePatternSet patterns(ctx);
populateLowerForeachToSCFPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
};
@@ -399,15 +411,17 @@ std::unique_ptr<Pass> mlir::createStageSparseOperationsPass() {
return std::make_unique<StageSparseOperationsPass>();
}
std::unique_ptr<Pass> mlir::createPostSparsificationRewritePass() {
return std::make_unique<PostSparsificationRewritePass>();
std::unique_ptr<Pass> mlir::createLowerSparseOpsToForeachPass() {
return std::make_unique<LowerSparseOpsToForeachPass>();
}
std::unique_ptr<Pass>
mlir::createPostSparsificationRewritePass(bool enableRT, bool enableForeach,
bool enableConvert) {
return std::make_unique<PostSparsificationRewritePass>(
enableRT, enableForeach, enableConvert);
mlir::createLowerSparseOpsToForeachPass(bool enableRT, bool enableConvert) {
return std::make_unique<LowerSparseOpsToForeachPass>(enableRT, enableConvert);
}
std::unique_ptr<Pass> mlir::createLowerForeachToSCFPass() {
return std::make_unique<LowerForeachToSCFPass>();
}
std::unique_ptr<Pass> mlir::createSparseTensorConversionPass() {

View File

@@ -1303,10 +1303,9 @@ void mlir::populatePreSparsificationRewriting(RewritePatternSet &patterns) {
GenSemiRingReduction, GenSemiRingSelect>(patterns.getContext());
}
void mlir::populatePostSparsificationRewriting(RewritePatternSet &patterns,
bool enableRT,
bool enableForeach,
bool enableConvert) {
void mlir::populateLowerSparseOpsToForeachPatterns(RewritePatternSet &patterns,
bool enableRT,
bool enableConvert) {
patterns.add<ConcatenateRewriter, CrdTranslateRewriter,
ReshapeRewriter<tensor::ExpandShapeOp>,
ReshapeRewriter<tensor::CollapseShapeOp>,
@@ -1314,10 +1313,13 @@ void mlir::populatePostSparsificationRewriting(RewritePatternSet &patterns,
Sparse2SparseReshapeRewriter<tensor::CollapseShapeOp>,
SparseTensorDimOpRewriter, TensorReshapeRewriter, OutRewriter>(
patterns.getContext());
if (enableForeach)
patterns.add<ForeachRewriter>(patterns.getContext());
if (enableConvert)
patterns.add<DirectConvertRewriter>(patterns.getContext());
if (!enableRT)
patterns.add<NewRewriter>(patterns.getContext());
}
void mlir::populateLowerForeachToSCFPatterns(RewritePatternSet &patterns) {
patterns.add<ForeachRewriter>(patterns.getContext());
}

View File

@@ -141,7 +141,10 @@ public:
OpPassManager pm("builtin.module");
pm.addPass(createSparsificationPass(sparsificationOptions));
pm.addNestedPass<func::FuncOp>(createStageSparseOperationsPass());
pm.addPass(createPostSparsificationRewritePass(enableRuntimeLibrary));
pm.addPass(createLowerSparseOpsToForeachPass(enableRuntimeLibrary,
/*enableConvert=*/true));
// TODO: DemapPass here!
pm.addNestedPass<func::FuncOp>(createLowerForeachToSCFPass());
if (vectorLength > 0) {
pm.addPass(mlir::createLoopInvariantCodeMotionPass());
pm.addPass(createSparseVectorizationPass(

View File

@@ -1,4 +1,4 @@
// RUN: mlir-opt %s --post-sparsification-rewrite --sparse-tensor-codegen --canonicalize -cse | FileCheck %s
// RUN: mlir-opt %s --lower-sparse-ops-to-foreach --lower-sparse-foreach-to-scf --sparse-tensor-codegen --canonicalize -cse | FileCheck %s
#SV = #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>

View File

@@ -1,4 +1,4 @@
// RUN: mlir-opt %s --post-sparsification-rewrite --sparse-tensor-conversion --canonicalize --cse | FileCheck %s
// RUN: mlir-opt %s --lower-sparse-ops-to-foreach --lower-sparse-foreach-to-scf --sparse-tensor-conversion --canonicalize --cse | FileCheck %s
#SparseVector = #sparse_tensor.encoding<{
map = (d0) -> (d0 : compressed)

View File

@@ -1,4 +1,4 @@
// RUN: mlir-opt %s --stage-sparse-ops --post-sparsification-rewrite="enable-foreach=false" --canonicalize --cse | FileCheck %s
// RUN: mlir-opt %s --stage-sparse-ops --lower-sparse-ops-to-foreach --canonicalize --cse | FileCheck %s
#SparseVector = #sparse_tensor.encoding<{
map = (d0) -> (d0 : compressed)

View File

@@ -1,4 +1,4 @@
// RUN: mlir-opt %s --stage-sparse-ops --post-sparsification-rewrite="enable-foreach=false" --canonicalize --cse | FileCheck %s
// RUN: mlir-opt %s --stage-sparse-ops --lower-sparse-ops-to-foreach --canonicalize --cse | FileCheck %s
#SparseVector = #sparse_tensor.encoding<{
map = (d0) -> (d0 : compressed)

View File

@@ -1,4 +1,4 @@
// RUN: mlir-opt %s --stage-sparse-ops --post-sparsification-rewrite="enable-foreach=false" --canonicalize --cse | FileCheck %s
// RUN: mlir-opt %s --stage-sparse-ops --lower-sparse-ops-to-foreach --canonicalize --cse | FileCheck %s
#SparseVector64 = #sparse_tensor.encoding<{
map = (d0) -> (d0 : compressed),

View File

@@ -1,5 +1,5 @@
// RUN: mlir-opt %s -post-sparsification-rewrite="enable-runtime-library=false enable-convert=false" | \
// RUN: FileCheck %s
// RUN: mlir-opt %s --lower-sparse-ops-to-foreach="enable-runtime-library=false enable-convert=false" \
// RUN: --lower-sparse-foreach-to-scf | FileCheck %s
#CSR = #sparse_tensor.encoding<{
map = (d0, d1) -> (d0 : dense, d1 : compressed)

View File

@@ -1,6 +1,6 @@
// RUN: mlir-opt %s --post-sparsification-rewrite="enable-runtime-library=false enable-convert=false" \
// RUN: mlir-opt %s --lower-sparse-ops-to-foreach="enable-runtime-library=false enable-convert=false" --lower-sparse-foreach-to-scf \
// RUN: | FileCheck %s
// RUN: mlir-opt %s --post-sparsification-rewrite="enable-runtime-library=true enable-convert=false" \
// RUN: mlir-opt %s --lower-sparse-ops-to-foreach="enable-runtime-library=true enable-convert=false" --lower-sparse-foreach-to-scf \
// RUN: | FileCheck %s

View File

@@ -4,7 +4,8 @@
// RUN: FileCheck %s --check-prefix=CHECK-SPARSE
// RUN: mlir-opt %s --linalg-generalize-named-ops \
// RUN: --linalg-fuse-elementwise-ops \
// RUN: --sparsification --post-sparsification-rewrite \
// RUN: --sparsification --lower-sparse-ops-to-foreach \
// RUN: --lower-sparse-foreach-to-scf \
// RUN: --sparse-tensor-conversion --cse | \
// RUN: FileCheck %s --check-prefix=CHECK-CONVERT

View File

@@ -1,4 +1,4 @@
// RUN: mlir-opt %s --post-sparsification-rewrite="enable-runtime-library=false enable-foreach=true" --canonicalize | FileCheck %s
// RUN: mlir-opt %s --lower-sparse-foreach-to-scf --canonicalize | FileCheck %s
// CHECK-LABEL: func.func @sparse_foreach_constant
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index

View File

@@ -1,4 +1,4 @@
// RUN: mlir-opt %s --canonicalize --post-sparsification-rewrite="enable-runtime-library=false" --sparse-tensor-codegen -cse --canonicalize | FileCheck %s
// RUN: mlir-opt %s --canonicalize --sparse-tensor-codegen -cse --canonicalize | FileCheck %s
#COO = #sparse_tensor.encoding<{
map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton),

View File

@@ -1,8 +1,8 @@
// RUN: mlir-opt %s | mlir-opt | FileCheck %s --check-prefix=CHECK-ROUND
// RUN: mlir-opt %s --post-sparsification-rewrite="enable-runtime-library=true enable-convert=false" \
// RUN: --cse --canonicalize | FileCheck %s
// RUN: mlir-opt %s --post-sparsification-rewrite="enable-runtime-library=false enable-convert=false" \
// RUN: --cse --canonicalize | FileCheck %s
// RUN: mlir-opt %s --lower-sparse-ops-to-foreach="enable-runtime-library=true enable-convert=false" \
// RUN: --lower-sparse-foreach-to-scf --cse --canonicalize | FileCheck %s
// RUN: mlir-opt %s --lower-sparse-ops-to-foreach="enable-runtime-library=false enable-convert=false" \
// RUN: --lower-sparse-foreach-to-scf --cse --canonicalize | FileCheck %s
#SparseVector = #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>
#SparseMatrix = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed, d1 : compressed) }>

View File

@@ -1,5 +1,5 @@
// RUN: mlir-opt %s --post-sparsification-rewrite="enable-runtime-library=false enable-convert=false" \
// RUN: --cse --canonicalize | FileCheck %s
// RUN: mlir-opt %s --lower-sparse-ops-to-foreach="enable-runtime-library=false enable-convert=false" \
// RUN: --lower-sparse-foreach-to-scf --cse --canonicalize | FileCheck %s
#SparseMatrix = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed, d1 : compressed) }>