//===- SparseTensorCodegen.cpp - Sparse tensor primitives conversion ------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // A pass that converts sparse tensor types and primitives to actual compiler // visible buffers and actual compiler IR that implements these primitives on // the selected sparse tensor storage schemes. This pass provides an alternative // to the SparseTensorConversion pass, eliminating the dependence on a runtime // support library, and providing much more opportunities for subsequent // compiler optimization of the generated code. // //===----------------------------------------------------------------------===// #include "CodegenUtils.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Transforms/DialectConversion.h" using namespace mlir; using namespace mlir::sparse_tensor; namespace { //===----------------------------------------------------------------------===// // Helper methods. //===----------------------------------------------------------------------===// /// Maps each sparse tensor type to the appropriate buffer. static Optional convertSparseTensorTypes(Type type) { if (getSparseTensorEncoding(type) != nullptr) { // TODO: this is just a dummy rule to get the ball rolling.... RankedTensorType rTp = type.cast(); return MemRefType::get({ShapedType::kDynamicSize}, rTp.getElementType()); } return llvm::None; } //===----------------------------------------------------------------------===// // Conversion rules. //===----------------------------------------------------------------------===// /// Sparse conversion rule for returns. class SparseReturnConverter : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp(op, adaptor.getOperands()); return success(); } }; } // namespace //===----------------------------------------------------------------------===// // Sparse tensor type conversion into an actual buffer. //===----------------------------------------------------------------------===// mlir::SparseTensorTypeToBufferConverter::SparseTensorTypeToBufferConverter() { addConversion([](Type type) { return type; }); addConversion(convertSparseTensorTypes); } //===----------------------------------------------------------------------===// // Public method for populating conversion rules. //===----------------------------------------------------------------------===// /// Populates the given patterns list with conversion rules required for /// the sparsification of linear algebra operations. void mlir::populateSparseTensorCodegenPatterns(TypeConverter &typeConverter, RewritePatternSet &patterns) { patterns.add(typeConverter, patterns.getContext()); }