mirror of
https://github.com/intel/llvm.git
synced 2026-01-29 04:16:38 +08:00
This new pass provides an alternative to the current conversion pass that converts sparse tensor types and sparse primitives to opaque pointers and calls into a runtime support library. This pass will map sparse tensor types to actual data structures and primitives to actual code. In the long run, this new pass will remove our dependence on the support library, avoid the need to link in fully templated and expanded code, and provide much better opportunities for optimization on the generated code. Reviewed By: Peiming Differential Revision: https://reviews.llvm.org/D132766
83 lines
3.4 KiB
C++
83 lines
3.4 KiB
C++
//===- SparseTensorCodegen.cpp - Sparse tensor primitives conversion ------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// A pass that converts sparse tensor types and primitives to actual compiler
|
|
// visible buffers and actual compiler IR that implements these primitives on
|
|
// the selected sparse tensor storage schemes. This pass provides an alternative
|
|
// to the SparseTensorConversion pass, eliminating the dependence on a runtime
|
|
// support library, and providing many more opportunities for subsequent
|
|
// compiler optimization of the generated code.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "CodegenUtils.h"
|
|
|
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
|
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
|
|
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
|
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
|
#include "mlir/Transforms/DialectConversion.h"
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::sparse_tensor;
|
|
|
|
namespace {
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Helper methods.
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Maps each sparse tensor type to the appropriate buffer.
|
|
static Optional<Type> convertSparseTensorTypes(Type type) {
|
|
if (getSparseTensorEncoding(type) != nullptr) {
|
|
// TODO: this is just a dummy rule to get the ball rolling....
|
|
RankedTensorType rTp = type.cast<RankedTensorType>();
|
|
return MemRefType::get({ShapedType::kDynamicSize}, rTp.getElementType());
|
|
}
|
|
return llvm::None;
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Conversion rules.
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Sparse conversion rule for returns.
|
|
class SparseReturnConverter : public OpConversionPattern<func::ReturnOp> {
|
|
public:
|
|
using OpConversionPattern::OpConversionPattern;
|
|
LogicalResult
|
|
matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
rewriter.replaceOpWithNewOp<func::ReturnOp>(op, adaptor.getOperands());
|
|
return success();
|
|
}
|
|
};
|
|
|
|
} // namespace
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Sparse tensor type conversion into an actual buffer.
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Builds the type converter used by sparse tensor codegen. Types without a
/// sparse tensor encoding are kept unchanged; sparse tensor types are
/// rewritten by convertSparseTensorTypes().
mlir::SparseTensorTypeToBufferConverter::SparseTensorTypeToBufferConverter() {
  // Fallback rule: leave the type untouched.
  auto keepTypeUnchanged = [](Type type) { return type; };
  addConversion(keepTypeUnchanged);
  // Per the MLIR TypeConverter documentation, the most recently added
  // conversion is tried first. This sparse rule therefore gets the first look,
  // and the identity rule above applies only when it returns llvm::None.
  addConversion(convertSparseTensorTypes);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Public method for populating conversion rules.
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Populates the given patterns list with conversion rules required for
|
|
/// the sparsification of linear algebra operations.
|
|
void mlir::populateSparseTensorCodegenPatterns(TypeConverter &typeConverter,
|
|
RewritePatternSet &patterns) {
|
|
patterns.add<SparseReturnConverter>(typeConverter, patterns.getContext());
|
|
}
|