mirror of
https://github.com/intel/llvm.git
synced 2026-01-17 23:45:25 +08:00
[mlir][sparse] Split SparseTensorRewrite into PreSparsificationRewrite and PostSparsificationRewrite.
Reviewed By: aartbik, wrengr Differential Revision: https://reviews.llvm.org/D138153
This commit is contained in:
@@ -138,16 +138,25 @@ std::unique_ptr<Pass>
|
||||
createSparseTensorCodegenPass(bool enableBufferInitialization);
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// The SparseTensorRewriting pass.
|
||||
// The PreSparsificationRewriting pass.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void populateSparseTensorRewriting(RewritePatternSet &patterns, bool enableRT,
|
||||
bool enableForeach, bool enableConvert);
|
||||
void populatePreSparsificationRewriting(RewritePatternSet &patterns);
|
||||
|
||||
std::unique_ptr<Pass> createSparseTensorRewritePass();
|
||||
std::unique_ptr<Pass> createSparseTensorRewritePass(bool enableRT,
|
||||
bool enableForeach = true,
|
||||
bool enableConvert = true);
|
||||
std::unique_ptr<Pass> createPreSparsificationRewritePass();
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// The PostSparsificationRewriting pass.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void populatePostSparsificationRewriting(RewritePatternSet &patterns,
|
||||
bool enableRT, bool enableForeach,
|
||||
bool enableConvert);
|
||||
|
||||
std::unique_ptr<Pass> createPostSparsificationRewritePass();
|
||||
std::unique_ptr<Pass>
|
||||
createPostSparsificationRewritePass(bool enableRT, bool enableForeach = true,
|
||||
bool enableConvert = true);
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Other rewriting rules and passes.
|
||||
|
||||
@@ -11,13 +11,13 @@
|
||||
|
||||
include "mlir/Pass/PassBase.td"
|
||||
|
||||
def SparseTensorRewrite : Pass<"sparse-tensor-rewrite", "ModuleOp"> {
|
||||
def PreSparsificationRewrite : Pass<"pre-sparsification-rewrite", "ModuleOp"> {
|
||||
let summary = "Applies sparse tensor rewriting rules prior to sparsification";
|
||||
let description = [{
|
||||
A pass that applies rewriting rules to sparse tensor operations prior
|
||||
to running the actual sparsification pass.
|
||||
}];
|
||||
let constructor = "mlir::createSparseTensorRewritePass()";
|
||||
let constructor = "mlir::createPreSparsificationRewritePass()";
|
||||
let dependentDialects = [
|
||||
"arith::ArithDialect",
|
||||
"bufferization::BufferizationDialect",
|
||||
@@ -26,14 +26,6 @@ def SparseTensorRewrite : Pass<"sparse-tensor-rewrite", "ModuleOp"> {
|
||||
"scf::SCFDialect",
|
||||
"sparse_tensor::SparseTensorDialect",
|
||||
];
|
||||
let options = [
|
||||
Option<"enableRuntimeLibrary", "enable-runtime-library", "bool",
|
||||
"true", "Enable runtime library for manipulating sparse tensors">,
|
||||
Option<"enableForeach", "enable-foreach", "bool",
|
||||
"true", "Enable rewriting rules for the foreach operator">,
|
||||
Option<"enableConvert", "enable-convert", "bool",
|
||||
"true", "Enable rewriting rules for the convert operator">,
|
||||
];
|
||||
}
|
||||
|
||||
def SparsificationPass : Pass<"sparsification", "ModuleOp"> {
|
||||
@@ -109,6 +101,31 @@ def SparsificationPass : Pass<"sparsification", "ModuleOp"> {
|
||||
];
|
||||
}
|
||||
|
||||
def PostSparsificationRewrite : Pass<"post-sparsification-rewrite", "ModuleOp"> {
|
||||
let summary = "Applies sparse tensor rewriting rules after sparsification";
|
||||
let description = [{
|
||||
A pass that applies rewriting rules to sparse tensor operations after
|
||||
running the actual sparsification pass.
|
||||
}];
|
||||
let constructor = "mlir::createPostSparsificationRewritePass()";
|
||||
let dependentDialects = [
|
||||
"arith::ArithDialect",
|
||||
"bufferization::BufferizationDialect",
|
||||
"linalg::LinalgDialect",
|
||||
"memref::MemRefDialect",
|
||||
"scf::SCFDialect",
|
||||
"sparse_tensor::SparseTensorDialect",
|
||||
];
|
||||
let options = [
|
||||
Option<"enableRuntimeLibrary", "enable-runtime-library", "bool",
|
||||
"true", "Enable runtime library for manipulating sparse tensors">,
|
||||
Option<"enableForeach", "enable-foreach", "bool",
|
||||
"true", "Enable rewriting rules for the foreach operator">,
|
||||
Option<"enableConvert", "enable-convert", "bool",
|
||||
"true", "Enable rewriting rules for the convert operator">,
|
||||
];
|
||||
}
|
||||
|
||||
def SparseTensorConversionPass : Pass<"sparse-tensor-conversion", "ModuleOp"> {
|
||||
let summary = "Convert sparse tensors and primitives to library calls";
|
||||
let description = [{
|
||||
|
||||
@@ -57,8 +57,9 @@ void mlir::sparse_tensor::buildSparseCompiler(
|
||||
/*analysisOnly=*/options.testBufferizationAnalysisOnly)));
|
||||
if (options.testBufferizationAnalysisOnly)
|
||||
return;
|
||||
pm.addPass(createSparseTensorRewritePass(options.enableRuntimeLibrary));
|
||||
pm.addPass(createPreSparsificationRewritePass());
|
||||
pm.addPass(createSparsificationPass(options.sparsificationOptions()));
|
||||
pm.addPass(createPostSparsificationRewritePass(options.enableRuntimeLibrary));
|
||||
if (options.enableRuntimeLibrary) {
|
||||
pm.addPass(createSparseTensorConversionPass(
|
||||
options.sparseTensorConversionOptions()));
|
||||
|
||||
@@ -21,8 +21,9 @@
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
|
||||
namespace mlir {
|
||||
#define GEN_PASS_DEF_SPARSETENSORREWRITE
|
||||
#define GEN_PASS_DEF_PRESPARSIFICATIONREWRITE
|
||||
#define GEN_PASS_DEF_SPARSIFICATIONPASS
|
||||
#define GEN_PASS_DEF_POSTSPARSIFICATIONREWRITE
|
||||
#define GEN_PASS_DEF_SPARSETENSORCONVERSIONPASS
|
||||
#define GEN_PASS_DEF_SPARSETENSORCODEGEN
|
||||
#define GEN_PASS_DEF_SPARSEBUFFERREWRITE
|
||||
@@ -38,22 +39,17 @@ namespace {
|
||||
// Passes implementation.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
struct SparseTensorRewritePass
|
||||
: public impl::SparseTensorRewriteBase<SparseTensorRewritePass> {
|
||||
struct PreSparsificationRewritePass
|
||||
: public impl::PreSparsificationRewriteBase<PreSparsificationRewritePass> {
|
||||
|
||||
SparseTensorRewritePass() = default;
|
||||
SparseTensorRewritePass(const SparseTensorRewritePass &pass) = default;
|
||||
SparseTensorRewritePass(bool enableRT, bool foreach, bool convert) {
|
||||
enableRuntimeLibrary = enableRT;
|
||||
enableForeach = foreach;
|
||||
enableConvert = convert;
|
||||
}
|
||||
PreSparsificationRewritePass() = default;
|
||||
PreSparsificationRewritePass(const PreSparsificationRewritePass &pass) =
|
||||
default;
|
||||
|
||||
void runOnOperation() override {
|
||||
auto *ctx = &getContext();
|
||||
RewritePatternSet patterns(ctx);
|
||||
populateSparseTensorRewriting(patterns, enableRuntimeLibrary, enableForeach,
|
||||
enableConvert);
|
||||
populatePreSparsificationRewriting(patterns);
|
||||
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
|
||||
}
|
||||
};
|
||||
@@ -80,6 +76,28 @@ struct SparsificationPass
|
||||
}
|
||||
};
|
||||
|
||||
struct PostSparsificationRewritePass
|
||||
: public impl::PostSparsificationRewriteBase<
|
||||
PostSparsificationRewritePass> {
|
||||
|
||||
PostSparsificationRewritePass() = default;
|
||||
PostSparsificationRewritePass(const PostSparsificationRewritePass &pass) =
|
||||
default;
|
||||
PostSparsificationRewritePass(bool enableRT, bool foreach, bool convert) {
|
||||
enableRuntimeLibrary = enableRT;
|
||||
enableForeach = foreach;
|
||||
enableConvert = convert;
|
||||
}
|
||||
|
||||
void runOnOperation() override {
|
||||
auto *ctx = &getContext();
|
||||
RewritePatternSet patterns(ctx);
|
||||
populatePostSparsificationRewriting(patterns, enableRuntimeLibrary,
|
||||
enableForeach, enableConvert);
|
||||
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
|
||||
}
|
||||
};
|
||||
|
||||
struct SparseTensorConversionPass
|
||||
: public impl::SparseTensorConversionPassBase<SparseTensorConversionPass> {
|
||||
|
||||
@@ -254,15 +272,8 @@ mlir::sparseToSparseConversionStrategy(int32_t flag) {
|
||||
// Pass creation methods.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
std::unique_ptr<Pass> mlir::createSparseTensorRewritePass() {
|
||||
return std::make_unique<SparseTensorRewritePass>();
|
||||
}
|
||||
|
||||
std::unique_ptr<Pass> mlir::createSparseTensorRewritePass(bool enableRT,
|
||||
bool enableForeach,
|
||||
bool enableConvert) {
|
||||
return std::make_unique<SparseTensorRewritePass>(enableRT, enableForeach,
|
||||
enableConvert);
|
||||
std::unique_ptr<Pass> mlir::createPreSparsificationRewritePass() {
|
||||
return std::make_unique<PreSparsificationRewritePass>();
|
||||
}
|
||||
|
||||
std::unique_ptr<Pass> mlir::createSparsificationPass() {
|
||||
@@ -274,6 +285,17 @@ mlir::createSparsificationPass(const SparsificationOptions &options) {
|
||||
return std::make_unique<SparsificationPass>(options);
|
||||
}
|
||||
|
||||
std::unique_ptr<Pass> mlir::createPostSparsificationRewritePass() {
|
||||
return std::make_unique<PostSparsificationRewritePass>();
|
||||
}
|
||||
|
||||
std::unique_ptr<Pass>
|
||||
mlir::createPostSparsificationRewritePass(bool enableRT, bool enableForeach,
|
||||
bool enableConvert) {
|
||||
return std::make_unique<PostSparsificationRewritePass>(
|
||||
enableRT, enableForeach, enableConvert);
|
||||
}
|
||||
|
||||
std::unique_ptr<Pass> mlir::createSparseTensorConversionPass() {
|
||||
return std::make_unique<SparseTensorConversionPass>();
|
||||
}
|
||||
|
||||
@@ -1021,11 +1021,17 @@ struct OutRewriter : public OpRewritePattern<OutOp> {
|
||||
//===---------------------------------------------------------------------===//
|
||||
// Methods that add patterns described in this file to a pattern list.
|
||||
//===---------------------------------------------------------------------===//
|
||||
void mlir::populateSparseTensorRewriting(RewritePatternSet &patterns,
|
||||
bool enableRT, bool enableForeach,
|
||||
bool enableConvert) {
|
||||
patterns.add<FoldInvariantYield, FuseSparseMultiplyOverAdd,
|
||||
ReshapeRewriter<tensor::ExpandShapeOp>,
|
||||
|
||||
void mlir::populatePreSparsificationRewriting(RewritePatternSet &patterns) {
|
||||
patterns.add<FoldInvariantYield, FuseSparseMultiplyOverAdd>(
|
||||
patterns.getContext());
|
||||
}
|
||||
|
||||
void mlir::populatePostSparsificationRewriting(RewritePatternSet &patterns,
|
||||
bool enableRT,
|
||||
bool enableForeach,
|
||||
bool enableConvert) {
|
||||
patterns.add<ReshapeRewriter<tensor::ExpandShapeOp>,
|
||||
ReshapeRewriter<tensor::CollapseShapeOp>>(patterns.getContext());
|
||||
if (enableForeach)
|
||||
patterns.add<ForeachRewriter>(patterns.getContext());
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// RUN: mlir-opt %s --sparse-tensor-conversion --canonicalize --cse | FileCheck %s
|
||||
// RUN: mlir-opt %s --sparse-tensor-rewrite="enable-runtime-library=false enable-foreach=false" \
|
||||
// RUN: mlir-opt %s --post-sparsification-rewrite="enable-runtime-library=false enable-foreach=false" \
|
||||
// RUN: --canonicalize --cse | FileCheck %s --check-prefix=CHECK-RWT
|
||||
|
||||
#SparseVector = #sparse_tensor.encoding<{
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
// RUN: mlir-opt %s --sparse-tensor-conversion --canonicalize --cse | FileCheck %s
|
||||
|
||||
// RUN: mlir-opt %s --sparse-tensor-rewrite="enable-runtime-library=false enable-foreach=false" \
|
||||
// RUN: mlir-opt %s --post-sparsification-rewrite="enable-runtime-library=false enable-foreach=false" \
|
||||
// RUN: --canonicalize --cse | FileCheck %s --check-prefix=CHECK-RWT
|
||||
|
||||
#SparseVector = #sparse_tensor.encoding<{
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
// RUN: mlir-opt %s --sparse-tensor-conversion="s2s-strategy=0" \
|
||||
// RUN: --canonicalize --cse | FileCheck %s -check-prefixes=CHECK-AUTO,CHECK
|
||||
|
||||
// RUN: mlir-opt %s --sparse-tensor-rewrite="enable-runtime-library=false enable-foreach=false" \
|
||||
// RUN: mlir-opt %s --post-sparsification-rewrite="enable-runtime-library=false enable-foreach=false" \
|
||||
// RUN: --canonicalize --cse | FileCheck %s --check-prefix=CHECK-RWT
|
||||
|
||||
#SparseVector64 = #sparse_tensor.encoding<{
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: mlir-opt %s -sparse-tensor-rewrite | FileCheck %s
|
||||
// RUN: mlir-opt %s -post-sparsification-rewrite | FileCheck %s
|
||||
|
||||
#SparseVector = #sparse_tensor.encoding<{
|
||||
dimLevelType = ["compressed"]
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: mlir-opt %s -sparse-tensor-rewrite="enable-runtime-library=false enable-convert=false" |\
|
||||
// RUN: mlir-opt %s -post-sparsification-rewrite="enable-runtime-library=false enable-convert=false" |\
|
||||
// RUN: FileCheck %s
|
||||
|
||||
#CSR = #sparse_tensor.encoding<{
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: mlir-opt %s --sparse-tensor-rewrite="enable-runtime-library=false enable-convert=false" \
|
||||
// RUN: mlir-opt %s --post-sparsification-rewrite="enable-runtime-library=false enable-convert=false" \
|
||||
// RUN: --sparsification | FileCheck %s
|
||||
|
||||
#DCSR = #sparse_tensor.encoding<{dimLevelType = ["compressed", "compressed"]}>
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: mlir-opt %s --linalg-generalize-named-ops --sparse-tensor-rewrite --sparsification --sparse-tensor-conversion --canonicalize --cse | FileCheck %s
|
||||
// RUN: mlir-opt %s --linalg-generalize-named-ops --pre-sparsification-rewrite --sparsification --sparse-tensor-conversion --canonicalize --cse | FileCheck %s
|
||||
|
||||
#DCSR = #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
// RUN: mlir-opt %s | mlir-opt | FileCheck %s --check-prefix=CHECK-ROUND
|
||||
// RUN: mlir-opt %s --sparse-tensor-conversion --cse --canonicalize | FileCheck %s --check-prefix=CHECK-CONV
|
||||
// RUN: mlir-opt %s --sparse-tensor-rewrite="enable-runtime-library=false enable-convert=false" \
|
||||
// RUN: mlir-opt %s --post-sparsification-rewrite="enable-runtime-library=false enable-convert=false" \
|
||||
// RUN: --cse --canonicalize | FileCheck %s --check-prefix=CHECK-RWT
|
||||
|
||||
#SparseVector = #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: mlir-opt %s --tensor-copy-insertion --sparse-tensor-rewrite --sparsification --cse | FileCheck %s
|
||||
// RUN: mlir-opt %s --tensor-copy-insertion --pre-sparsification-rewrite --sparsification --cse | FileCheck %s
|
||||
|
||||
#SM = #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>
|
||||
|
||||
|
||||
Reference in New Issue
Block a user