[sparse] allow unpack op to return any integer type. (#66161)

This commit is contained in:
Peiming Liu
2023-09-12 17:27:51 -07:00
committed by GitHub
parent 749ec26d83
commit 64df1c08d0
3 changed files with 11 additions and 8 deletions

View File

@@ -108,8 +108,8 @@ def SparseTensor_UnpackOp : SparseTensor_Op<"unpack", [Pure, SameVariadicResultS
Variadic<TensorOf<[AnySignlessIntegerOrIndex]>>:$out_levels)>,
Results<(outs TensorOf<[AnyType]>:$ret_values,
Variadic<TensorOf<[AnySignlessIntegerOrIndex]>>:$ret_levels,
Index:$val_len,
Variadic<Index>:$lvl_lens)> {
AnySignlessIntegerOrIndex:$val_len,
Variadic<AnySignlessIntegerOrIndex>:$lvl_lens)> {
let summary = "Returns the (values, coordinates) pair unpacked from the input tensor";
let description = [{

View File

@@ -1323,7 +1323,8 @@ struct SparseUnpackOpConverter : public OpConversionPattern<UnpackOp> {
// TODO: maybe change unpack/pack operation instead to be
// consistent.
retMem.insert(retMem.begin(), dst);
retLen.insert(retLen.begin(), sz);
Type valLenTp = op.getValLen().getType();
retLen.insert(retLen.begin(), genCast(rewriter, loc, sz, valLenTp));
} else {
assert(fKind == SparseTensorFieldKind::PosMemRef ||
fKind == SparseTensorFieldKind::CrdMemRef);
@@ -1334,7 +1335,9 @@ struct SparseUnpackOpConverter : public OpConversionPattern<UnpackOp> {
src = desc.getMemRefField(fid);
dst = genToMemref(rewriter, loc, op.getOutLevels()[fid]);
retMem.push_back(dst);
retLen.push_back(sz);
// Retrieves the corresponding level length type.
Type lvlLenTp = op.getLvlLens().getTypes()[retLen.size()];
retLen.push_back(genCast(rewriter, loc, sz, lvlLenTp));
}
Value flatOut = dst;
if (dst.getType().getRank() != 1) {

View File

@@ -178,7 +178,7 @@ module {
%i_csr = tensor.empty() : tensor<3xi32>
%rd_csr, %rp_csr, %ri_csr, %ld_csr, %lp_csr, %li_csr = sparse_tensor.unpack %csr : tensor<2x2xf64, #CSR>
outs(%d_csr, %p_csr, %i_csr : tensor<4xf64>, tensor<3xi32>, tensor<3xi32>)
-> tensor<4xf64>, (tensor<3xi32>, tensor<3xi32>), index, (index, index)
-> tensor<4xf64>, (tensor<3xi32>, tensor<3xi32>), index, (i32, i64)
// CHECK-NEXT: ( 1, 2, 3, {{.*}} )
%vd_csr = vector.transfer_read %rd_csr[%c0], %f0 : tensor<4xf64>, vector<4xf64>
@@ -203,7 +203,7 @@ module {
%oi = tensor.empty() : tensor<3x2xi32>
%d, %p, %i, %dl, %pl, %il = sparse_tensor.unpack %s5 : tensor<10x10xf64, #SortedCOOI32>
outs(%od, %op, %oi : tensor<3xf64>, tensor<2xi32>, tensor<3x2xi32>)
-> tensor<3xf64>, (tensor<2xi32>, tensor<3x2xi32>), index, (index, index)
-> tensor<3xf64>, (tensor<2xi32>, tensor<3x2xi32>), index, (i32, i64)
// CHECK-NEXT: ( 1, 2, 3 )
%vd = vector.transfer_read %d[%c0], %f0 : tensor<3xf64>, vector<3xf64>
@@ -219,7 +219,7 @@ module {
%boi = tensor.empty() : tensor<6x2xindex>
%bd, %bp, %bi, %ld, %lp, %li = sparse_tensor.unpack %bs : tensor<2x10x10xf64, #BCOO>
outs(%bod, %bop, %boi : tensor<6xf64>, tensor<4xindex>, tensor<6x2xindex>)
-> tensor<6xf64>, (tensor<4xindex>, tensor<6x2xindex>), index, (index, index)
-> tensor<6xf64>, (tensor<4xindex>, tensor<6x2xindex>), index, (i32, i64)
// CHECK-NEXT: ( 1, 2, 3, 4, 5, {{.*}} )
%vbd = vector.transfer_read %bd[%c0], %f0 : tensor<6xf64>, vector<6xf64>
@@ -231,7 +231,7 @@ module {
%vbi = vector.transfer_read %bi[%c0, %c0], %c0 : tensor<6x2xindex>, vector<6x2xindex>
vector.print %vbi : vector<6x2xindex>
// CHECK-NEXT: 10
vector.print %li : index
vector.print %li : i64
return
}