[mlir][sparse] codegen for sparse dealloc

Reviewed By: bixia

Differential Revision: https://reviews.llvm.org/D133171
Author: Aart Bik
Date: 2022-09-01 17:18:56 -07:00
Commit: 2ddfacd95c (parent 11881a8f3f)

3 changed files with 56 additions and 9 deletions


@@ -17,6 +17,7 @@
 #include "CodegenUtils.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
@@ -232,7 +233,31 @@ public:
   }
 };
-/// Sparse conversion rule for pointer accesses.
+/// Sparse codegen rule for the dealloc operator.
+class SparseTensorDeallocConverter
+    : public OpConversionPattern<bufferization::DeallocTensorOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(bufferization::DeallocTensorOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto enc = getSparseTensorEncoding(op.getTensor().getType());
+    if (!enc)
+      return failure();
+    // Replace the tuple deallocation with field deallocations.
+    Location loc = op->getLoc();
+    Value tuple = adaptor.getTensor();
+    for (unsigned i = 0, sz = tuple.getType().cast<TupleType>().size(); i < sz;
+         i++) {
+      Value mem = createTupleGet(rewriter, loc, tuple, i);
+      rewriter.create<memref::DeallocOp>(loc, mem);
+    }
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+/// Sparse codegen rule for pointer accesses.
 class SparseToPointersConverter : public OpConversionPattern<ToPointersOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
@@ -251,7 +276,7 @@ public:
   }
 };
-/// Sparse conversion rule for index accesses.
+/// Sparse codegen rule for index accesses.
 class SparseToIndicesConverter : public OpConversionPattern<ToIndicesOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
@@ -270,7 +295,7 @@ public:
   }
 };
-/// Sparse conversion rule for value accesses.
+/// Sparse codegen rule for value accesses.
 class SparseToValuesConverter : public OpConversionPattern<ToValuesOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
@@ -280,7 +305,7 @@ public:
     // Replace the requested values access with corresponding field.
     Location loc = op->getLoc();
     Value tuple = adaptor.getTensor();
-    unsigned i = tuple.getType().cast<TupleType>().size() - 1; // last
+    unsigned i = tuple.getType().cast<TupleType>().size() - 1; // last
     rewriter.replaceOp(op, createTupleGet(rewriter, loc, tuple, i));
     return success();
   }
@@ -306,6 +331,7 @@ mlir::SparseTensorTypeToBufferConverter::SparseTensorTypeToBufferConverter() {
 void mlir::populateSparseTensorCodegenPatterns(TypeConverter &typeConverter,
                                                RewritePatternSet &patterns) {
   patterns.add<SparseReturnConverter, SparseDimOpConverter, SparseCastConverter,
-               SparseToPointersConverter, SparseToIndicesConverter,
-               SparseToValuesConverter>(typeConverter, patterns.getContext());
+               SparseTensorDeallocConverter, SparseToPointersConverter,
+               SparseToIndicesConverter, SparseToValuesConverter>(
+      typeConverter, patterns.getContext());
 }
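
Note on the helper used above: the new SparseTensorDeallocConverter calls createTupleGet, which is defined earlier in this file and does not appear in these hunks. Going by the sparse_tensor.storage_get lines the test at the end of this patch expects, the helper extracts one field of the storage tuple. The following is only a sketch of what such a helper plausibly looks like; the exact builder signature is an assumption, not taken from this patch.

// Sketch only (assumed signature): extract field `fld` from the sparse
// storage tuple by emitting a sparse_tensor.storage_get, whose result type
// is the type of that tuple field.
static Value createTupleGet(OpBuilder &builder, Location loc, Value tuple,
                            unsigned fld) {
  Type fieldType = tuple.getType().cast<TupleType>().getType(fld);
  return builder.create<StorageGetOp>(loc, fieldType, tuple,
                                      builder.getIndexAttr(fld));
}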

View File

@@ -124,8 +124,7 @@ struct SparseTensorConversionPass
    });
    // The following operations and dialects may be introduced by the
    // rewriting rules, and are therefore marked as legal.
-    target.addLegalOp<bufferization::ToMemrefOp, bufferization::ToTensorOp,
-                      complex::ConstantOp, complex::NotEqualOp, linalg::FillOp,
+    target.addLegalOp<complex::ConstantOp, complex::NotEqualOp, linalg::FillOp,
                       linalg::YieldOp, tensor::ExtractOp>();
    target.addLegalDialect<
        arith::ArithmeticDialect, bufferization::BufferizationDialect,
@@ -160,7 +159,9 @@ struct SparseTensorCodegenPass
    // Almost everything in the sparse dialect must go!
    target.addIllegalDialect<SparseTensorDialect>();
    target.addLegalOp<StorageGetOp, StorageSetOp>();
-    // All dynamic rules below accept new function, call, return.
+    // All dynamic rules below accept new function, call, return, and various
+    // tensor and bufferization operations as legal output of the rewriting
+    // provided that all sparse tensor types have been fully rewritten.
    target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
      return converter.isSignatureLegal(op.getFunctionType());
    });
@@ -170,6 +171,10 @@ struct SparseTensorCodegenPass
    target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
      return converter.isLegal(op.getOperandTypes());
    });
+    target.addDynamicallyLegalOp<bufferization::DeallocTensorOp>(
+        [&](bufferization::DeallocTensorOp op) {
+          return converter.isLegal(op.getTensor().getType());
+        });
    // Legal dialects may occur in generated code.
    target.addLegalDialect<arith::ArithmeticDialect,
                           bufferization::BufferizationDialect,


@@ -141,3 +141,19 @@ func.func @sparse_values_dcsr(%arg0: tensor<?x?xf64, #DCSR>) -> memref<?xf64> {
   %0 = sparse_tensor.values %arg0 : tensor<?x?xf64, #DCSR> to memref<?xf64>
   return %0 : memref<?xf64>
 }
+// CHECK-LABEL: func @sparse_dealloc_csr(
+// CHECK-SAME: %[[A:.*]]: tuple<memref<2xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>>)
+// CHECK: %[[F0:.*]] = sparse_tensor.storage_get %[[A]][0] : tuple<memref<2xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>> to memref<2xindex>
+// CHECK: memref.dealloc %[[F0]] : memref<2xindex>
+// CHECK: %[[F1:.*]] = sparse_tensor.storage_get %[[A]][1] : tuple<memref<2xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>> to memref<?xi32>
+// CHECK: memref.dealloc %[[F1]] : memref<?xi32>
+// CHECK: %[[F2:.*]] = sparse_tensor.storage_get %[[A]][2] : tuple<memref<2xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>> to memref<?xi64>
+// CHECK: memref.dealloc %[[F2]] : memref<?xi64>
+// CHECK: %[[F3:.*]] = sparse_tensor.storage_get %[[A]][3] : tuple<memref<2xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>> to memref<?xf64>
+// CHECK: memref.dealloc %[[F3]] : memref<?xf64>
+// CHECK: return
+func.func @sparse_dealloc_csr(%arg0: tensor<?x?xf64, #CSR>) {
+  bufferization.dealloc_tensor %arg0 : tensor<?x?xf64, #CSR>
+  return
+}
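
For end-to-end use, the codegen patterns registered in this patch are normally driven by the sparse tensor codegen pass, and a FileCheck test like the one above is typically run through mlir-opt with that pass enabled (assuming the pass keeps its usual sparse-tensor-codegen registration name). A minimal C++ pipeline sketch, assuming the standard createSparseTensorCodegenPass() entry point from the SparseTensor transforms headers:

#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Pass/PassManager.h"

// Sketch only: schedules the codegen pass, which with this patch also
// rewrites bufferization.dealloc_tensor on annotated sparse tensors into
// the per-field memref.dealloc sequence checked above.
static void buildSparseDeallocPipeline(mlir::PassManager &pm) {
  pm.addPass(mlir::createSparseTensorCodegenPass());
}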