[mlir][sparse] Replace getSparseTensorType with tryGetSparseTensorType (#109435)
This PR fixes a bug in `SparseTensorDimOpRewriter` when `tensor.dim` has an unranked tensor type. To prevent crashes, we now use `tryGetSparseTensorType` instead of `getSparseTensorType`. Fixes #107807.
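For context, the two helpers differ in how they handle values whose type is not a ranked tensor (e.g. the unranked `tensor<*xf32>` operand that triggered #107807). Below is a minimal sketch of what their declarations imply, based on how they are used in this patch; the actual definitions live in the sparse tensor dialect headers, and the exact signatures shown here are assumptions, not a copy of the real code.

```cpp
#include <optional>

// Unconditional variant: casts the value's type to RankedTensorType, so it
// asserts (crashes) when handed an unranked tensor such as tensor<*xf32>.
inline SparseTensorType getSparseTensorType(Value val) {
  return SparseTensorType(cast<RankedTensorType>(val.getType()));
}

// Failable variant: returns std::nullopt when the type is not a ranked
// tensor, letting rewrite patterns bail out with failure() instead of
// crashing.
inline std::optional<SparseTensorType> tryGetSparseTensorType(Value val) {
  if (auto rtp = dyn_cast<RankedTensorType>(val.getType()))
    return SparseTensorType(rtp);
  return std::nullopt;
}
```

With the optional in hand, the rewriters below now guard with `if (!stt) return failure();` (and dereference via `->`) before using the sparse tensor type, as the diff shows.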
@@ -881,25 +881,27 @@ public:
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
     Value srcTensor = op.getSource();
-    const auto srcTp = getSparseTensorType(srcTensor);
-    const auto dstTp = getSparseTensorType(op.getResult());
+    const auto srcTp = tryGetSparseTensorType(srcTensor);
+    const auto dstTp = tryGetSparseTensorType(op.getResult());
+    if (!srcTp || !dstTp)
+      return failure();
 
-    if (!srcTp.hasEncoding() || !dstTp.hasEncoding() ||
-        !dstTp.hasStaticDimShape())
+    if (!srcTp->hasEncoding() || !dstTp->hasEncoding() ||
+        !dstTp->hasStaticDimShape())
       return failure();
 
     SmallVector<Value> srcSizes;
-    sizesForTensor(rewriter, srcSizes, loc, srcTp, srcTensor);
+    sizesForTensor(rewriter, srcSizes, loc, *srcTp, srcTensor);
     SmallVector<Value> dstSizes;
-    for (Dimension d : dstTp.getDimShape())
+    for (Dimension d : dstTp->getDimShape())
       dstSizes.push_back(constantIndex(rewriter, loc, d));
 
     Value nnz = rewriter.create<NumberOfEntriesOp>(loc, srcTensor);
     // Only need an unordered COO buffer if input and output are not sorted
     // in the same way.
     Type bufferTp = getBufferType(
-        dstTp.withoutDimToLvl(),
-        !srcTp.isAllOrdered() || !srcTp.isIdentity() || !dstTp.isIdentity());
+        dstTp->withoutDimToLvl(),
+        !srcTp->isAllOrdered() || !srcTp->isIdentity() || !dstTp->isIdentity());
     SmallVector<Value> dynSizes;
     Value buffer = rewriter
                        .create<AllocTensorOp>(loc, bufferTp, dynSizes, Value(),
@@ -917,12 +919,12 @@ public:
     // followed by an optional
     // %t = sparse_tensor.cast %tmp
     // depending on whether the input/output are sorted in the same way.
-    const auto encSrc = srcTp.getEncoding();
+    const auto encSrc = srcTp->getEncoding();
     ForeachOp foreachOp = rewriter.create<ForeachOp>(
         loc, srcTensor, buffer,
         [&](OpBuilder &builder, Location loc, ValueRange srcLcvs, Value v,
             ValueRange reduc) {
-          const Dimension srcRank = srcTp.getDimRank();
+          const Dimension srcRank = srcTp->getDimRank();
           SmallVector<Value> srcDcvs;
           srcDcvs.reserve(srcRank);
           for (Dimension d = 0; d < srcRank; d++) {
@@ -945,7 +947,7 @@ public:
                      collapsedSizes, collapsedDcvs);
 
           ReassociationIndices expandIdx;
-          for (Dimension i = 0; i < dstTp.getDimRank(); i++)
+          for (Dimension i = 0; i < dstTp->getDimRank(); i++)
             expandIdx.push_back(i);
           SmallVector<ReassociationIndices, 1> expandReass = {expandIdx};
           SmallVector<Value> dstDcvs;
@@ -958,8 +960,8 @@ public:
         });
 
     Value t = rewriter.create<LoadOp>(loc, foreachOp.getResult(0), true);
-    if (bufferTp != dstTp) {
-      auto dstRTT = dstTp.getRankedTensorType();
+    if (bufferTp != *dstTp) {
+      auto dstRTT = dstTp->getRankedTensorType();
       Value converted = rewriter.create<ConvertOp>(loc, dstRTT, t).getResult();
       rewriter.create<DeallocTensorOp>(loc, t);
       t = converted;
@@ -1139,13 +1141,13 @@ struct SparseTensorDimOpRewriter : public OpRewritePattern<tensor::DimOp> {
   LogicalResult matchAndRewrite(tensor::DimOp op,
                                 PatternRewriter &rewriter) const override {
     std::optional<int64_t> dim = op.getConstantIndex();
-    auto stt = getSparseTensorType(op.getSource());
-    if (!dim || !stt.hasEncoding())
+    auto stt = tryGetSparseTensorType(op.getSource());
+    if (!dim || !stt || !stt->hasEncoding())
       return failure();
 
-    if (stt.isPermutation()) {
+    if (stt->isPermutation()) {
       rewriter.replaceOpWithNewOp<LvlOp>(op, op.getSource(),
-                                         toLvl(stt.getEncoding(), *dim));
+                                         toLvl(stt->getEncoding(), *dim));
       return success();
     }
 
@@ -1157,16 +1159,16 @@ struct SparseTensorDimOpRewriter : public OpRewritePattern<tensor::DimOp> {
     // computed simply by lvl_size * block_size.
     Location loc = op.getLoc();
     SmallVector<Value> maxLvlCrds;
-    for (Level l = 0; l < stt.getLvlRank(); l++) {
+    for (Level l = 0; l < stt->getLvlRank(); l++) {
       Value lvlSz = rewriter.create<LvlOp>(loc, op.getSource(), l);
       Value maxLvlCrd = rewriter.create<arith::SubIOp>(
           loc, lvlSz, constantOne(rewriter, loc, rewriter.getIndexType()));
       maxLvlCrds.push_back(maxLvlCrd);
     }
 
-    AffineExpr lvl2DimExp = stt.getLvlToDim().getResult(*dim);
+    AffineExpr lvl2DimExp = stt->getLvlToDim().getResult(*dim);
     Value maxDimCrd = rewriter.create<affine::AffineApplyOp>(
-        op.getLoc(), AffineMap::get(stt.getLvlRank(), 0, lvl2DimExp),
+        op.getLoc(), AffineMap::get(stt->getLvlRank(), 0, lvl2DimExp),
         maxLvlCrds);
 
     Value dimSz = rewriter.create<arith::AddIOp>(
@@ -826,3 +826,19 @@ func.func @sparse_new_coo_permute_no(%arg0: !llvm.ptr) -> tensor<?x?xf32, #CooPNo>
   %0 = sparse_tensor.new %arg0 : !llvm.ptr to tensor<?x?xf32, #CooPNo>
   return %0 : tensor<?x?xf32, #CooPNo>
 }
+
+// CHECK-LABEL: func.func @test_tensor_dim_unranked
+// CHECK: tensor.dim
+func.func @test_tensor_dim_unranked(%arg0: tensor<*xf32>) -> index {
+  %c = arith.constant 0 : index
+  %0 = tensor.dim %arg0, %c : tensor<*xf32>
+  return %0 : index
+}
+
+// CHECK-LABEL: func.func @test_tensor_reshape_unranked
+// CHECK: tensor.reshape
+func.func @test_tensor_reshape_unranked(%src: tensor<*xf32>, %shape: tensor<1xi32>) -> tensor<?xf32> {
+  %dst = tensor.reshape %src(%shape)
+         : (tensor<*xf32>, tensor<1xi32>) -> tensor<?xf32>
+  return %dst : tensor<?xf32>
+}