mirror of
https://github.com/intel/llvm.git
synced 2026-01-28 01:04:49 +08:00
[mlir][sparse] remove most bufferization.alloc_tensor ops from sparse (#66847)
The only ones left need actual deprecation in bufferization module.
This commit is contained in:
@@ -705,6 +705,7 @@ public:
|
||||
};
|
||||
|
||||
/// Sparse codegen rule for the alloc operator.
|
||||
/// TODO(springerm): remove when bufferization.alloc_tensor is gone
|
||||
class SparseTensorAllocConverter
|
||||
: public OpConversionPattern<bufferization::AllocTensorOp> {
|
||||
public:
|
||||
@@ -764,6 +765,46 @@ private:
|
||||
bool enableBufferInitialization;
|
||||
};
|
||||
|
||||
/// Sparse codegen rule for the empty tensor operator.
|
||||
/// TODO(springerm): remove when bufferization.alloc_tensor is gone
|
||||
class SparseTensorEmptyConverter : public OpConversionPattern<tensor::EmptyOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
SparseTensorEmptyConverter(TypeConverter &typeConverter, MLIRContext *context,
|
||||
bool enableInit)
|
||||
: OpConversionPattern(typeConverter, context),
|
||||
enableBufferInitialization(enableInit) {}
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(tensor::EmptyOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
const auto resType = getSparseTensorType(op);
|
||||
if (!resType.hasEncoding())
|
||||
return failure();
|
||||
|
||||
// Construct allocation for each field.
|
||||
const Location loc = op.getLoc();
|
||||
const Value sizeHint; // none
|
||||
const ValueRange dynSizes = adaptor.getDynamicSizes();
|
||||
const size_t found = dynSizes.size();
|
||||
const int64_t expected = resType.getNumDynamicDims();
|
||||
if (found != static_cast<size_t>(expected))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, llvm::formatv(
|
||||
"Got wrong number of dynamic sizes: Found={0}, Expected={1}",
|
||||
found, expected));
|
||||
SmallVector<Value> fields;
|
||||
createAllocFields(rewriter, loc, resType, dynSizes,
|
||||
enableBufferInitialization, fields, sizeHint);
|
||||
// Replace operation with resulting memrefs.
|
||||
rewriter.replaceOp(op, genTuple(rewriter, loc, resType, fields));
|
||||
return success();
|
||||
}
|
||||
|
||||
private:
|
||||
bool enableBufferInitialization;
|
||||
};
|
||||
|
||||
/// Sparse codegen rule for the dealloc operator.
|
||||
class SparseTensorDeallocConverter
|
||||
: public OpConversionPattern<bufferization::DeallocTensorOp> {
|
||||
@@ -1546,6 +1587,6 @@ void mlir::populateSparseTensorCodegenPatterns(
|
||||
patterns.getContext());
|
||||
patterns.add<SparseTensorDeallocConverter>(
|
||||
typeConverter, patterns.getContext(), createSparseDeallocs);
|
||||
patterns.add<SparseTensorAllocConverter>(typeConverter, patterns.getContext(),
|
||||
enableBufferInitialization);
|
||||
patterns.add<SparseTensorAllocConverter, SparseTensorEmptyConverter>(
|
||||
typeConverter, patterns.getContext(), enableBufferInitialization);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user