2021-09-22 11:48:57 -07:00
|
|
|
//===- SparseTensorPasses.cpp - Pass for autogen sparse tensor code -------===//
|
2021-05-03 20:55:12 -07:00
|
|
|
//
|
|
|
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
|
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
|
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
|
|
|
//
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
2022-04-22 14:17:08 -07:00
|
|
|
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
|
|
|
|
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
2021-11-25 11:42:16 +01:00
|
|
|
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
|
2022-05-19 19:33:33 -07:00
|
|
|
#include "mlir/Dialect/Complex/IR/Complex.h"
|
2022-02-28 14:25:39 -08:00
|
|
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
2022-02-26 14:49:54 -08:00
|
|
|
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
|
2021-05-03 20:55:12 -07:00
|
|
|
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
|
|
|
|
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
|
2022-07-27 00:16:20 +00:00
|
|
|
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
|
2021-05-03 20:55:12 -07:00
|
|
|
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
|
|
|
|
|
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
|
2022-04-22 14:17:08 -07:00
|
|
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
2021-05-03 20:55:12 -07:00
|
|
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
|
|
|
|
|
2022-08-31 10:16:29 +02:00
|
|
|
namespace mlir {
|
|
|
|
|
#define GEN_PASS_DEF_SPARSIFICATIONPASS
|
|
|
|
|
#define GEN_PASS_DEF_SPARSETENSORCONVERSIONPASS
|
|
|
|
|
#define GEN_PASS_DEF_SPARSETENSORCODEGEN
|
|
|
|
|
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
|
|
|
|
|
} // namespace mlir
|
|
|
|
|
|
2021-05-03 20:55:12 -07:00
|
|
|
using namespace mlir;
|
2021-05-10 10:34:21 -07:00
|
|
|
using namespace mlir::sparse_tensor;
|
2021-05-03 20:55:12 -07:00
|
|
|
|
|
|
|
|
namespace {
|
|
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
// Passes implementation.
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
2022-08-31 10:16:29 +02:00
|
|
|
struct SparsificationPass
|
|
|
|
|
: public impl::SparsificationPassBase<SparsificationPass> {
|
2021-05-03 20:55:12 -07:00
|
|
|
|
|
|
|
|
SparsificationPass() = default;
|
2022-01-02 22:01:50 +00:00
|
|
|
SparsificationPass(const SparsificationPass &pass) = default;
|
2022-01-26 16:44:32 -08:00
|
|
|
SparsificationPass(const SparsificationOptions &options) {
|
2022-09-04 01:39:35 +00:00
|
|
|
parallelization = options.parallelizationStrategy;
|
|
|
|
|
vectorization = options.vectorizationStrategy;
|
2022-01-26 16:44:32 -08:00
|
|
|
vectorLength = options.vectorLength;
|
|
|
|
|
enableSIMDIndex32 = options.enableSIMDIndex32;
|
2021-12-02 15:09:33 +00:00
|
|
|
enableVLAVectorization = options.enableVLAVectorization;
|
2022-09-09 18:37:59 +00:00
|
|
|
enableRuntimeLibrary = options.enableRuntimeLibrary;
|
2021-05-03 20:55:12 -07:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void runOnOperation() override {
|
|
|
|
|
auto *ctx = &getContext();
|
2022-07-15 16:41:02 -07:00
|
|
|
RewritePatternSet prePatterns(ctx);
|
2021-05-03 20:55:12 -07:00
|
|
|
// Translate strategy flags to strategy options.
|
2022-09-04 01:39:35 +00:00
|
|
|
SparsificationOptions options(parallelization, vectorization, vectorLength,
|
2022-09-09 18:37:59 +00:00
|
|
|
enableSIMDIndex32, enableVLAVectorization,
|
|
|
|
|
enableRuntimeLibrary);
|
|
|
|
|
// Apply pre-rewriting.
|
|
|
|
|
populateSparseTensorRewriting(prePatterns, options.enableRuntimeLibrary);
|
|
|
|
|
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(prePatterns));
|
2022-07-15 16:41:02 -07:00
|
|
|
// Apply sparsification and vector cleanup rewriting.
|
|
|
|
|
RewritePatternSet patterns(ctx);
|
2021-05-03 20:55:12 -07:00
|
|
|
populateSparsificationPatterns(patterns, options);
|
|
|
|
|
vector::populateVectorToVectorCanonicalizationPatterns(patterns);
|
2022-09-16 15:22:48 -07:00
|
|
|
scf::ForOp::getCanonicalizationPatterns(patterns, ctx);
|
2021-05-03 20:55:12 -07:00
|
|
|
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
struct SparseTensorConversionPass
|
2022-08-31 10:16:29 +02:00
|
|
|
: public impl::SparseTensorConversionPassBase<SparseTensorConversionPass> {
|
2022-03-18 19:10:40 -07:00
|
|
|
|
|
|
|
|
SparseTensorConversionPass() = default;
|
|
|
|
|
SparseTensorConversionPass(const SparseTensorConversionPass &pass) = default;
|
|
|
|
|
SparseTensorConversionPass(const SparseTensorConversionOptions &options) {
|
|
|
|
|
sparseToSparse = static_cast<int32_t>(options.sparseToSparseStrategy);
|
|
|
|
|
}
|
|
|
|
|
|
2021-05-03 20:55:12 -07:00
|
|
|
void runOnOperation() override {
|
|
|
|
|
auto *ctx = &getContext();
|
2021-05-10 10:34:21 -07:00
|
|
|
RewritePatternSet patterns(ctx);
|
2022-08-26 13:49:07 -07:00
|
|
|
SparseTensorTypeToPtrConverter converter;
|
2021-05-03 20:55:12 -07:00
|
|
|
ConversionTarget target(*ctx);
|
2021-10-20 12:47:31 -07:00
|
|
|
// Everything in the sparse dialect must go!
|
|
|
|
|
target.addIllegalDialect<SparseTensorDialect>();
|
2022-06-21 14:13:14 -07:00
|
|
|
// All dynamic rules below accept new function, call, return, and various
|
|
|
|
|
// tensor and bufferization operations as legal output of the rewriting
|
|
|
|
|
// provided that all sparse tensor types have been fully rewritten.
|
2022-04-18 11:53:47 -07:00
|
|
|
target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
|
2022-03-15 17:36:15 -07:00
|
|
|
return converter.isSignatureLegal(op.getFunctionType());
|
|
|
|
|
});
|
2022-02-26 14:49:54 -08:00
|
|
|
target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
|
2021-05-10 10:34:21 -07:00
|
|
|
return converter.isSignatureLegal(op.getCalleeType());
|
|
|
|
|
});
|
2022-02-26 14:49:54 -08:00
|
|
|
target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
|
|
|
|
|
return converter.isLegal(op.getOperandTypes());
|
|
|
|
|
});
|
2021-08-23 10:29:19 -07:00
|
|
|
target.addDynamicallyLegalOp<tensor::DimOp>([&](tensor::DimOp op) {
|
|
|
|
|
return converter.isLegal(op.getOperandTypes());
|
|
|
|
|
});
|
2021-10-20 12:47:31 -07:00
|
|
|
target.addDynamicallyLegalOp<tensor::CastOp>([&](tensor::CastOp op) {
|
2022-07-10 21:19:11 -07:00
|
|
|
return converter.isLegal(op.getSource().getType()) &&
|
|
|
|
|
converter.isLegal(op.getDest().getType());
|
2021-10-20 12:47:31 -07:00
|
|
|
});
|
2022-07-01 16:57:40 -07:00
|
|
|
target.addDynamicallyLegalOp<tensor::ExpandShapeOp>(
|
|
|
|
|
[&](tensor::ExpandShapeOp op) {
|
2022-07-10 21:19:11 -07:00
|
|
|
return converter.isLegal(op.getSrc().getType()) &&
|
|
|
|
|
converter.isLegal(op.getResult().getType());
|
2022-07-01 16:57:40 -07:00
|
|
|
});
|
|
|
|
|
target.addDynamicallyLegalOp<tensor::CollapseShapeOp>(
|
|
|
|
|
[&](tensor::CollapseShapeOp op) {
|
2022-07-10 21:19:11 -07:00
|
|
|
return converter.isLegal(op.getSrc().getType()) &&
|
|
|
|
|
converter.isLegal(op.getResult().getType());
|
2022-07-01 16:57:40 -07:00
|
|
|
});
|
2022-06-21 14:13:14 -07:00
|
|
|
target.addDynamicallyLegalOp<bufferization::AllocTensorOp>(
|
|
|
|
|
[&](bufferization::AllocTensorOp op) {
|
|
|
|
|
return converter.isLegal(op.getType());
|
|
|
|
|
});
|
2022-07-19 09:13:53 +02:00
|
|
|
target.addDynamicallyLegalOp<bufferization::DeallocTensorOp>(
|
|
|
|
|
[&](bufferization::DeallocTensorOp op) {
|
|
|
|
|
return converter.isLegal(op.getTensor().getType());
|
|
|
|
|
});
|
2021-08-23 10:29:19 -07:00
|
|
|
// The following operations and dialects may be introduced by the
|
|
|
|
|
// rewriting rules, and are therefore marked as legal.
|
2022-09-01 17:18:56 -07:00
|
|
|
target.addLegalOp<complex::ConstantOp, complex::NotEqualOp, linalg::FillOp,
|
2022-07-08 21:12:25 -07:00
|
|
|
linalg::YieldOp, tensor::ExtractOp>();
|
|
|
|
|
target.addLegalDialect<
|
|
|
|
|
arith::ArithmeticDialect, bufferization::BufferizationDialect,
|
|
|
|
|
LLVM::LLVMDialect, memref::MemRefDialect, scf::SCFDialect>();
|
2022-03-18 19:10:40 -07:00
|
|
|
// Translate strategy flags to strategy options.
|
|
|
|
|
SparseTensorConversionOptions options(
|
|
|
|
|
sparseToSparseConversionStrategy(sparseToSparse));
|
2021-08-23 10:29:19 -07:00
|
|
|
// Populate with rules and apply rewriting rules.
|
2022-04-18 11:53:47 -07:00
|
|
|
populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
|
|
|
|
|
converter);
|
2021-05-10 10:34:21 -07:00
|
|
|
populateCallOpTypeConversionPattern(patterns, converter);
|
2022-07-27 00:16:20 +00:00
|
|
|
scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
|
|
|
|
|
target);
|
2022-03-18 19:10:40 -07:00
|
|
|
populateSparseTensorConversionPatterns(converter, patterns, options);
|
2021-05-03 20:55:12 -07:00
|
|
|
if (failed(applyPartialConversion(getOperation(), target,
|
2021-05-10 10:34:21 -07:00
|
|
|
std::move(patterns))))
|
2021-05-03 20:55:12 -07:00
|
|
|
signalPassFailure();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
2022-08-26 13:49:07 -07:00
|
|
|
struct SparseTensorCodegenPass
|
2022-08-31 10:16:29 +02:00
|
|
|
: public impl::SparseTensorCodegenBase<SparseTensorCodegenPass> {
|
2022-08-26 13:49:07 -07:00
|
|
|
|
|
|
|
|
SparseTensorCodegenPass() = default;
|
|
|
|
|
SparseTensorCodegenPass(const SparseTensorCodegenPass &pass) = default;
|
|
|
|
|
|
|
|
|
|
void runOnOperation() override {
|
|
|
|
|
auto *ctx = &getContext();
|
|
|
|
|
RewritePatternSet patterns(ctx);
|
|
|
|
|
SparseTensorTypeToBufferConverter converter;
|
|
|
|
|
ConversionTarget target(*ctx);
|
2022-09-27 17:06:20 -07:00
|
|
|
// Most ops in the sparse dialect must go!
|
2022-08-26 13:49:07 -07:00
|
|
|
target.addIllegalDialect<SparseTensorDialect>();
|
2022-09-27 17:06:20 -07:00
|
|
|
target.addLegalOp<SortOp>();
|
2022-09-01 17:18:56 -07:00
|
|
|
// All dynamic rules below accept new function, call, return, and various
|
|
|
|
|
// tensor and bufferization operations as legal output of the rewriting
|
|
|
|
|
// provided that all sparse tensor types have been fully rewritten.
|
2022-08-26 13:49:07 -07:00
|
|
|
target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
|
|
|
|
|
return converter.isSignatureLegal(op.getFunctionType());
|
|
|
|
|
});
|
|
|
|
|
target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
|
|
|
|
|
return converter.isSignatureLegal(op.getCalleeType());
|
|
|
|
|
});
|
|
|
|
|
target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
|
|
|
|
|
return converter.isLegal(op.getOperandTypes());
|
|
|
|
|
});
|
2022-09-02 17:54:17 -07:00
|
|
|
target.addDynamicallyLegalOp<bufferization::AllocTensorOp>(
|
|
|
|
|
[&](bufferization::AllocTensorOp op) {
|
|
|
|
|
return converter.isLegal(op.getType());
|
|
|
|
|
});
|
2022-09-01 17:18:56 -07:00
|
|
|
target.addDynamicallyLegalOp<bufferization::DeallocTensorOp>(
|
|
|
|
|
[&](bufferization::DeallocTensorOp op) {
|
|
|
|
|
return converter.isLegal(op.getTensor().getType());
|
|
|
|
|
});
|
2022-09-07 14:34:04 -07:00
|
|
|
// The following operations and dialects may be introduced by the
|
|
|
|
|
// codegen rules, and are therefore marked as legal.
|
|
|
|
|
target.addLegalOp<linalg::FillOp>();
|
2022-08-31 18:22:04 -07:00
|
|
|
target.addLegalDialect<arith::ArithmeticDialect,
|
|
|
|
|
bufferization::BufferizationDialect,
|
|
|
|
|
memref::MemRefDialect, scf::SCFDialect>();
|
2022-09-01 20:34:05 +00:00
|
|
|
target.addLegalOp<UnrealizedConversionCastOp>();
|
2022-09-01 17:06:31 +00:00
|
|
|
// Populate with rules and apply rewriting rules.
|
|
|
|
|
populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
|
|
|
|
|
converter);
|
|
|
|
|
scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
|
|
|
|
|
target);
|
2022-09-07 00:49:44 +00:00
|
|
|
populateSparseTensorCodegenPatterns(converter, patterns);
|
2022-09-01 17:06:31 +00:00
|
|
|
if (failed(applyPartialConversion(getOperation(), target,
|
|
|
|
|
std::move(patterns))))
|
|
|
|
|
signalPassFailure();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
2021-12-07 18:27:58 +00:00
|
|
|
} // namespace
|
2021-05-03 20:55:12 -07:00
|
|
|
|
2022-08-26 13:49:07 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
// Strategy flag methods.
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
2022-03-18 19:10:40 -07:00
|
|
|
/// Maps the integer pass flag onto a sparse-to-sparse conversion strategy:
/// 1 selects kViaCOO, 2 selects kDirect, and any other value selects kAuto.
SparseToSparseConversionStrategy
mlir::sparseToSparseConversionStrategy(int32_t flag) {
  if (flag == 1)
    return SparseToSparseConversionStrategy::kViaCOO;
  if (flag == 2)
    return SparseToSparseConversionStrategy::kDirect;
  return SparseToSparseConversionStrategy::kAuto;
}
|
|
|
|
|
|
2022-08-26 13:49:07 -07:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
// Pass creation methods.
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
2021-05-03 20:55:12 -07:00
|
|
|
std::unique_ptr<Pass> mlir::createSparsificationPass() {
|
|
|
|
|
return std::make_unique<SparsificationPass>();
|
|
|
|
|
}
|
|
|
|
|
|
2022-01-26 16:44:32 -08:00
|
|
|
std::unique_ptr<Pass>
|
|
|
|
|
mlir::createSparsificationPass(const SparsificationOptions &options) {
|
|
|
|
|
return std::make_unique<SparsificationPass>(options);
|
|
|
|
|
}
|
|
|
|
|
|
2021-05-03 20:55:12 -07:00
|
|
|
std::unique_ptr<Pass> mlir::createSparseTensorConversionPass() {
|
|
|
|
|
return std::make_unique<SparseTensorConversionPass>();
|
|
|
|
|
}
|
2022-03-18 19:10:40 -07:00
|
|
|
|
|
|
|
|
std::unique_ptr<Pass> mlir::createSparseTensorConversionPass(
|
|
|
|
|
const SparseTensorConversionOptions &options) {
|
|
|
|
|
return std::make_unique<SparseTensorConversionPass>(options);
|
|
|
|
|
}
|
2022-08-26 13:49:07 -07:00
|
|
|
|
|
|
|
|
std::unique_ptr<Pass> mlir::createSparseTensorCodegenPass() {
|
|
|
|
|
return std::make_unique<SparseTensorCodegenPass>();
|
|
|
|
|
}
|