mirror of
https://github.com/intel/llvm.git
synced 2026-01-28 19:43:38 +08:00
[mlir][sparse] codegen for sparse dealloc
Reviewed By: bixia Differential Revision: https://reviews.llvm.org/D133171
This commit is contained in:
@@ -17,6 +17,7 @@
|
||||
|
||||
#include "CodegenUtils.h"
|
||||
|
||||
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
|
||||
@@ -232,7 +233,31 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
/// Sparse conversion rule for pointer accesses.
|
||||
/// Sparse codegen rule for the dealloc operator.
|
||||
class SparseTensorDeallocConverter
|
||||
: public OpConversionPattern<bufferization::DeallocTensorOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(bufferization::DeallocTensorOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto enc = getSparseTensorEncoding(op.getTensor().getType());
|
||||
if (!enc)
|
||||
return failure();
|
||||
// Replace the tuple deallocation with field deallocations.
|
||||
Location loc = op->getLoc();
|
||||
Value tuple = adaptor.getTensor();
|
||||
for (unsigned i = 0, sz = tuple.getType().cast<TupleType>().size(); i < sz;
|
||||
i++) {
|
||||
Value mem = createTupleGet(rewriter, loc, tuple, i);
|
||||
rewriter.create<memref::DeallocOp>(loc, mem);
|
||||
}
|
||||
rewriter.eraseOp(op);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/// Sparse codegen rule for pointer accesses.
|
||||
class SparseToPointersConverter : public OpConversionPattern<ToPointersOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
@@ -251,7 +276,7 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
/// Sparse conversion rule for index accesses.
|
||||
/// Sparse codegen rule for index accesses.
|
||||
class SparseToIndicesConverter : public OpConversionPattern<ToIndicesOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
@@ -270,7 +295,7 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
/// Sparse conversion rule for value accesses.
|
||||
/// Sparse codegen rule for value accesses.
|
||||
class SparseToValuesConverter : public OpConversionPattern<ToValuesOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
@@ -280,7 +305,7 @@ public:
|
||||
// Replace the requested values access with corresponding field.
|
||||
Location loc = op->getLoc();
|
||||
Value tuple = adaptor.getTensor();
|
||||
unsigned i = tuple.getType().cast<TupleType>().size() - 1; // last
|
||||
unsigned i = tuple.getType().cast<TupleType>().size() - 1; // last
|
||||
rewriter.replaceOp(op, createTupleGet(rewriter, loc, tuple, i));
|
||||
return success();
|
||||
}
|
||||
@@ -306,6 +331,7 @@ mlir::SparseTensorTypeToBufferConverter::SparseTensorTypeToBufferConverter() {
|
||||
void mlir::populateSparseTensorCodegenPatterns(TypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns) {
|
||||
patterns.add<SparseReturnConverter, SparseDimOpConverter, SparseCastConverter,
|
||||
SparseToPointersConverter, SparseToIndicesConverter,
|
||||
SparseToValuesConverter>(typeConverter, patterns.getContext());
|
||||
SparseTensorDeallocConverter, SparseToPointersConverter,
|
||||
SparseToIndicesConverter, SparseToValuesConverter>(
|
||||
typeConverter, patterns.getContext());
|
||||
}
|
||||
|
||||
@@ -124,8 +124,7 @@ struct SparseTensorConversionPass
|
||||
});
|
||||
// The following operations and dialects may be introduced by the
|
||||
// rewriting rules, and are therefore marked as legal.
|
||||
target.addLegalOp<bufferization::ToMemrefOp, bufferization::ToTensorOp,
|
||||
complex::ConstantOp, complex::NotEqualOp, linalg::FillOp,
|
||||
target.addLegalOp<complex::ConstantOp, complex::NotEqualOp, linalg::FillOp,
|
||||
linalg::YieldOp, tensor::ExtractOp>();
|
||||
target.addLegalDialect<
|
||||
arith::ArithmeticDialect, bufferization::BufferizationDialect,
|
||||
@@ -160,7 +159,9 @@ struct SparseTensorCodegenPass
|
||||
// Almost everything in the sparse dialect must go!
|
||||
target.addIllegalDialect<SparseTensorDialect>();
|
||||
target.addLegalOp<StorageGetOp, StorageSetOp>();
|
||||
// All dynamic rules below accept new function, call, return.
|
||||
// All dynamic rules below accept new function, call, return, and various
|
||||
// tensor and bufferization operations as legal output of the rewriting
|
||||
// provided that all sparse tensor types have been fully rewritten.
|
||||
target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
|
||||
return converter.isSignatureLegal(op.getFunctionType());
|
||||
});
|
||||
@@ -170,6 +171,10 @@ struct SparseTensorCodegenPass
|
||||
target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
|
||||
return converter.isLegal(op.getOperandTypes());
|
||||
});
|
||||
target.addDynamicallyLegalOp<bufferization::DeallocTensorOp>(
|
||||
[&](bufferization::DeallocTensorOp op) {
|
||||
return converter.isLegal(op.getTensor().getType());
|
||||
});
|
||||
// Legal dialects may occur in generated code.
|
||||
target.addLegalDialect<arith::ArithmeticDialect,
|
||||
bufferization::BufferizationDialect,
|
||||
|
||||
@@ -141,3 +141,19 @@ func.func @sparse_values_dcsr(%arg0: tensor<?x?xf64, #DCSR>) -> memref<?xf64> {
|
||||
%0 = sparse_tensor.values %arg0 : tensor<?x?xf64, #DCSR> to memref<?xf64>
|
||||
return %0 : memref<?xf64>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @sparse_dealloc_csr(
|
||||
// CHECK-SAME: %[[A:.*]]: tuple<memref<2xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>>)
|
||||
// CHECK: %[[F0:.*]] = sparse_tensor.storage_get %[[A]][0] : tuple<memref<2xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>> to memref<2xindex>
|
||||
// CHECK: memref.dealloc %[[F0]] : memref<2xindex>
|
||||
// CHECK: %[[F1:.*]] = sparse_tensor.storage_get %[[A]][1] : tuple<memref<2xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>> to memref<?xi32>
|
||||
// CHECK: memref.dealloc %[[F1]] : memref<?xi32>
|
||||
// CHECK: %[[F2:.*]] = sparse_tensor.storage_get %[[A]][2] : tuple<memref<2xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>> to memref<?xi64>
|
||||
// CHECK: memref.dealloc %[[F2]] : memref<?xi64>
|
||||
// CHECK: %[[F3:.*]] = sparse_tensor.storage_get %[[A]][3] : tuple<memref<2xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>> to memref<?xf64>
|
||||
// CHECK: memref.dealloc %[[F3]] : memref<?xf64>
|
||||
// CHECK: return
|
||||
func.func @sparse_dealloc_csr(%arg0: tensor<?x?xf64, #CSR>) {
|
||||
bufferization.dealloc_tensor %arg0 : tensor<?x?xf64, #CSR>
|
||||
return
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user