[mlir][sparse][NFC] Switch InitOp to bufferization::AllocTensorOp
Now that we have an AllocTensorOp (previously InitTensorOp) in the bufferization dialect, the InitOp in the sparse dialect is no longer needed.

Differential Revision: https://reviews.llvm.org/D126180
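In IR terms, the change replaces the dedicated sparse init op with the bufferization dialect's alloc_tensor op. A minimal before/after sketch (the #SparseMatrix encoding and the %d0/%d1 size values are illustrative, and the exact assembly syntax may vary by MLIR version):

    #SparseMatrix = #sparse_tensor.encoding<{ dimLevelType = ["dense", "compressed"] }>

    // Before: all dimension sizes were listed on the sparse dialect's init op.
    %t = sparse_tensor.init [%d0, %d1] : tensor<?x?xf64, #SparseMatrix>

    // After: the bufferization op is used instead; dynamic sizes become operands.
    %t = bufferization.alloc_tensor(%d0, %d1) : tensor<?x?xf64, #SparseMatrix>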
@@ -459,22 +459,33 @@ class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
   }
 };

-/// Sparse conversion rule for the init operator.
-class SparseTensorInitConverter : public OpConversionPattern<InitOp> {
+/// Sparse conversion rule for the alloc operator.
+class SparseTensorAllocConverter
+    : public OpConversionPattern<bufferization::AllocTensorOp> {
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(InitOp op, OpAdaptor adaptor,
+  matchAndRewrite(bufferization::AllocTensorOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    Type resType = op.getType();
+    RankedTensorType resType = op.getType();
     auto enc = getSparseTensorEncoding(resType);
     if (!enc)
       return failure();
+    // Gather all dimension sizes as SSA values.
+    SmallVector<Value> sizes;
+    unsigned int operandCtr = 0;
+    for (int64_t i = 0; i < resType.getRank(); ++i) {
+      if (resType.isDynamicDim(i)) {
+        sizes.push_back(adaptor.getOperands()[operandCtr++]);
+      } else {
+        sizes.push_back(rewriter.create<arith::ConstantIndexOp>(
+            op.getLoc(), op.getStaticSize(i)));
+      }
+    }
     // Generate the call to construct empty tensor. The sizes are
-    // explicitly defined by the arguments to the init operator.
+    // explicitly defined by the arguments to the alloc operator.
     SmallVector<Value, 8> params;
     ShapedType stp = resType.cast<ShapedType>();
-    newParams(rewriter, params, op, stp, enc, Action::kEmpty,
-              adaptor.getOperands());
+    newParams(rewriter, params, op, stp, enc, Action::kEmpty, sizes);
     rewriter.replaceOp(op, genNewCall(rewriter, op, params));
     return success();
   }

@@ -912,7 +923,7 @@ void mlir::populateSparseTensorConversionPatterns(
     const SparseTensorConversionOptions &options) {
   patterns.add<SparseReturnConverter, SparseTensorToDimSizeConverter,
                SparseCastConverter, SparseTensorNewConverter,
-               SparseTensorInitConverter, SparseTensorReleaseConverter,
+               SparseTensorAllocConverter, SparseTensorReleaseConverter,
                SparseTensorToPointersConverter, SparseTensorToIndicesConverter,
                SparseTensorToValuesConverter, SparseTensorLoadConverter,
                SparseTensorLexInsertConverter, SparseTensorExpandConverter,
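Because bufferization.alloc_tensor only takes operands for the dynamic dimensions (the old init op listed every dimension size), the new converter rebuilds the full per-dimension size list itself, materializing index constants for the static sizes before handing everything to newParams()/genNewCall(). A hypothetical sketch for a partially static allocation such as tensor<8x?xf64, #SparseMatrix> with one dynamic size %n:

    // Only the dynamic size %n is an operand of the alloc op:
    //   %t = bufferization.alloc_tensor(%n) : tensor<8x?xf64, #SparseMatrix>
    // The converter materializes the static dimension as a constant ...
    %c8 = arith.constant 8 : index
    // ... and forwards sizes = [%c8, %n] to the runtime construction call.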