[mlir][sparse] replace support lib conversion with actual MLIR codegen

Rationale:
Passing in a pointer to the memref data in order to implement the
dense to sparse conversion was a bit too low-level. This revision
improves upon that approach with a cleaner solution of generating
a loop nest in MLIR code itself that prepares the COO object before
passing it to our "swiss army knife" setup.  This is much more
intuitive *and* now also allows for dynamic shapes.

Reviewed By: bixia

Differential Revision: https://reviews.llvm.org/D108491
This commit is contained in:
Aart Bik
2021-08-23 10:29:19 -07:00
parent eff11176c5
commit 236a90802d
5 changed files with 231 additions and 122 deletions

View File

@@ -99,6 +99,9 @@ struct SparseTensorConversionPass
ConversionTarget target(*ctx);
target.addIllegalOp<NewOp, ConvertOp, ToPointersOp, ToIndicesOp, ToValuesOp,
ToTensorOp>();
// All dynamic rules below accept new function, call, return, and dimop
// operations as legal output of the rewriting provided that all sparse
// tensor types have been fully rewritten.
target.addDynamicallyLegalOp<FuncOp>(
[&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
target.addDynamicallyLegalOp<CallOp>([&](CallOp op) {
@@ -106,8 +109,15 @@ struct SparseTensorConversionPass
});
target.addDynamicallyLegalOp<ReturnOp>(
[&](ReturnOp op) { return converter.isLegal(op.getOperandTypes()); });
target.addLegalOp<ConstantOp, tensor::CastOp, memref::BufferCastOp,
memref::CastOp>();
target.addDynamicallyLegalOp<tensor::DimOp>([&](tensor::DimOp op) {
return converter.isLegal(op.getOperandTypes());
});
// The following operations and dialects may be introduced by the
// rewriting rules, and are therefore marked as legal.
target.addLegalOp<ConstantOp, tensor::CastOp, tensor::ExtractOp>();
target.addLegalDialect<scf::SCFDialect, LLVM::LLVMDialect,
memref::MemRefDialect>();
// Populate with rules and apply rewriting rules.
populateFuncOpTypeConversionPattern(patterns, converter);
populateCallOpTypeConversionPattern(patterns, converter);
populateSparseTensorConversionPatterns(converter, patterns);