[sparse] allow unpack op to return any integer type. (#66161)
@@ -108,8 +108,8 @@ def SparseTensor_UnpackOp : SparseTensor_Op<"unpack", [Pure, SameVariadicResultS
                        Variadic<TensorOf<[AnySignlessIntegerOrIndex]>>:$out_levels)>,
     Results<(outs TensorOf<[AnyType]>:$ret_values,
                   Variadic<TensorOf<[AnySignlessIntegerOrIndex]>>:$ret_levels,
-                  Index:$val_len,
-                  Variadic<Index>:$lvl_lens)> {
+                  AnySignlessIntegerOrIndex:$val_len,
+                  Variadic<AnySignlessIntegerOrIndex>:$lvl_lens)> {
   let summary = "Returns the (values, coordinates) pair unpacked from the input tensor";

   let description = [{
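With this change, the val_len and lvl_lens results are no longer pinned to the index type; any signless integer type is accepted as well. A minimal sketch of what the relaxed signature permits (value names are illustrative; the shape mirrors the updated CSR test below):

    // Hypothetical IR: the two level lengths are returned as i32 and i64
    // rather than the previously mandatory index type.
    %v, %p, %c, %vl, %pl, %cl = sparse_tensor.unpack %csr : tensor<2x2xf64, #CSR>
        outs(%vo, %po, %co : tensor<4xf64>, tensor<3xi32>, tensor<3xi32>)
        -> tensor<4xf64>, (tensor<3xi32>, tensor<3xi32>), index, (i32, i64)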
@@ -1323,7 +1323,8 @@ struct SparseUnpackOpConverter : public OpConversionPattern<UnpackOp> {
         // TODO: maybe change unpack/pack operation instead to be
         // consistent.
         retMem.insert(retMem.begin(), dst);
-        retLen.insert(retLen.begin(), sz);
+        Type valLenTp = op.getValLen().getType();
+        retLen.insert(retLen.begin(), genCast(rewriter, loc, sz, valLenTp));
       } else {
         assert(fKind == SparseTensorFieldKind::PosMemRef ||
                fKind == SparseTensorFieldKind::CrdMemRef);
@@ -1334,7 +1335,9 @@ struct SparseUnpackOpConverter : public OpConversionPattern<UnpackOp> {
         src = desc.getMemRefField(fid);
         dst = genToMemref(rewriter, loc, op.getOutLevels()[fid]);
         retMem.push_back(dst);
-        retLen.push_back(sz);
+        // Retrieves the corresponding level length type.
+        Type lvlLenTp = op.getLvlLens().getTypes()[retLen.size()];
+        retLen.push_back(genCast(rewriter, loc, sz, lvlLenTp));
       }
       Value flatOut = dst;
       if (dst.getType().getRank() != 1) {
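In both branches the length sz is still computed as an index-typed value; genCast now bridges it to whatever type the op declares for that result. A minimal sketch of the IR this amounts to for a declared i32 length, assuming the cast materializes as the standard arith.index_cast (genCast is presumably a no-op when source and target types already match):

    // index-typed size cast to the declared i32 length result.
    %len_i32 = arith.index_cast %sz : index to i32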
@@ -178,7 +178,7 @@ module {
     %i_csr = tensor.empty() : tensor<3xi32>
     %rd_csr, %rp_csr, %ri_csr, %ld_csr, %lp_csr, %li_csr = sparse_tensor.unpack %csr : tensor<2x2xf64, #CSR>
                  outs(%d_csr, %p_csr, %i_csr : tensor<4xf64>, tensor<3xi32>, tensor<3xi32>)
-                 -> tensor<4xf64>, (tensor<3xi32>, tensor<3xi32>), index, (index, index)
+                 -> tensor<4xf64>, (tensor<3xi32>, tensor<3xi32>), index, (i32, i64)

     // CHECK-NEXT: ( 1, 2, 3, {{.*}} )
     %vd_csr = vector.transfer_read %rd_csr[%c0], %f0 : tensor<4xf64>, vector<4xf64>
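Note that the level lengths now come back as i32 and i64 while the value length stays index. If downstream code still needs index-typed lengths, they can be cast back; an illustrative snippet (not part of the test):

    // Recover an index value from the i32 position-array length.
    %lp_idx = arith.index_cast %lp_csr : i32 to index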
@@ -203,7 +203,7 @@ module {
     %oi = tensor.empty() : tensor<3x2xi32>
     %d, %p, %i, %dl, %pl, %il = sparse_tensor.unpack %s5 : tensor<10x10xf64, #SortedCOOI32>
                  outs(%od, %op, %oi : tensor<3xf64>, tensor<2xi32>, tensor<3x2xi32>)
-                 -> tensor<3xf64>, (tensor<2xi32>, tensor<3x2xi32>), index, (index, index)
+                 -> tensor<3xf64>, (tensor<2xi32>, tensor<3x2xi32>), index, (i32, i64)

     // CHECK-NEXT: ( 1, 2, 3 )
     %vd = vector.transfer_read %d[%c0], %f0 : tensor<3xf64>, vector<3xf64>
@@ -219,7 +219,7 @@ module {
     %boi = tensor.empty() : tensor<6x2xindex>
     %bd, %bp, %bi, %ld, %lp, %li = sparse_tensor.unpack %bs : tensor<2x10x10xf64, #BCOO>
                  outs(%bod, %bop, %boi : tensor<6xf64>, tensor<4xindex>, tensor<6x2xindex>)
-                 -> tensor<6xf64>, (tensor<4xindex>, tensor<6x2xindex>), index, (index, index)
+                 -> tensor<6xf64>, (tensor<4xindex>, tensor<6x2xindex>), index, (i32, i64)

     // CHECK-NEXT: ( 1, 2, 3, 4, 5, {{.*}} )
     %vbd = vector.transfer_read %bd[%c0], %f0 : tensor<6xf64>, vector<6xf64>
@@ -231,7 +231,7 @@ module {
     %vbi = vector.transfer_read %bi[%c0, %c0], %c0 : tensor<6x2xindex>, vector<6x2xindex>
     vector.print %vbi : vector<6x2xindex>
     // CHECK-NEXT: 10
-    vector.print %li : index
+    vector.print %li : i64

     return
   }