Files
llvm/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
Aart Bik 86b22d3120 [mlir][sparse] start a sparse codegen conversion pass
This new pass provides an alternative to the current conversion pass
that converts sparse tensor types and sparse primitives to opaque pointers
and calls into a runtime support library. This pass will map sparse tensor
types to actual data structures and primitives to actual code. In the long
run, this new pass will remove our dependence on the support library, avoid
the need to link in fully templated and expanded code, and provide much better
opportunities for optimization on the generated code.

Reviewed By: Peiming

Differential Revision: https://reviews.llvm.org/D132766
2022-08-29 09:39:33 -07:00

83 lines
3.4 KiB
C++

//===- SparseTensorCodegen.cpp - Sparse tensor primitives conversion ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// A pass that converts sparse tensor types and primitives to actual compiler
// visible buffers and actual compiler IR that implements these primitives on
// the selected sparse tensor storage schemes. This pass provides an alternative
// to the SparseTensorConversion pass, eliminating the dependence on a runtime
// support library, and providing many more opportunities for subsequent
// compiler optimization of the generated code.
//
//===----------------------------------------------------------------------===//
#include "CodegenUtils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/DialectConversion.h"
using namespace mlir;
using namespace mlir::sparse_tensor;
namespace {
//===----------------------------------------------------------------------===//
// Helper methods.
//===----------------------------------------------------------------------===//
/// Type conversion rule for sparse tensors.
///
/// Returns llvm::None for any type that carries no sparse tensor encoding.
/// Otherwise returns a one-dimensional memref of dynamic size whose element
/// type equals the element type of the incoming tensor.
static Optional<Type> convertSparseTensorTypes(Type type) {
  // Guard: only types with a sparse tensor encoding are handled here.
  if (getSparseTensorEncoding(type) == nullptr)
    return llvm::None;
  // TODO: this is just a dummy rule to get the ball rolling; a real mapping
  // to the actual storage scheme still has to be implemented.
  RankedTensorType rankedTensorTp = type.cast<RankedTensorType>();
  return MemRefType::get({ShapedType::kDynamicSize},
                         rankedTensorTp.getElementType());
}
//===----------------------------------------------------------------------===//
// Conversion rules.
//===----------------------------------------------------------------------===//
/// Conversion pattern for `func::ReturnOp`.
///
/// Replaces the matched return with a new `func::ReturnOp` that is built from
/// the operands supplied by the adaptor, and always reports success.
class SparseReturnConverter : public OpConversionPattern<func::ReturnOp> {
public:
  // Inherit all constructors of the base pattern class.
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(func::ReturnOp returnOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Use the adaptor's operands instead of the operands of the original op.
    // NOTE(review): in MLIR's dialect conversion framework these are expected
    // to be the already type-converted values -- confirm against the docs.
    ValueRange newOperands = adaptor.getOperands();
    rewriter.replaceOpWithNewOp<func::ReturnOp>(returnOp, newOperands);
    return success();
  }
};
} // namespace
//===----------------------------------------------------------------------===//
// Sparse tensor type conversion into an actual buffer.
//===----------------------------------------------------------------------===//
/// Constructs the converter and registers its two type conversion rules:
/// an identity rule and the sparse tensor rule `convertSparseTensorTypes`.
mlir::SparseTensorTypeToBufferConverter::SparseTensorTypeToBufferConverter() {
// Identity rule: maps every type to itself.
// NOTE(review): MLIR's TypeConverter is documented to try the most recently
// added conversion first, so the registration order below is presumably
// significant -- confirm before reordering these calls.
addConversion([](Type type) { return type; });
// Sparse tensor rule: yields a buffer type for sparse-encoded tensor types
// and llvm::None for everything else.
addConversion(convertSparseTensorTypes);
}
//===----------------------------------------------------------------------===//
// Public method for populating conversion rules.
//===----------------------------------------------------------------------===//
/// Adds the sparse tensor codegen conversion rules to `patterns`.
///
/// Currently this registers a single rule, `SparseReturnConverter`, which is
/// constructed with `typeConverter` and the context owned by `patterns`.
void mlir::populateSparseTensorCodegenPatterns(TypeConverter &typeConverter,
                                               RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  patterns.add<SparseReturnConverter>(typeConverter, context);
}