[mlir][sparse] Fix bugs in computing the memory size when lowering the pack op.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D151481
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -1242,10 +1242,11 @@ struct SparsePackOpConverter : public OpConversionPattern<PackOp> {
         });
 
     MutSparseTensorDescriptor desc(stt, fields);
+    Value c0 = constantIndex(rewriter, loc, 0);
     Value c1 = constantIndex(rewriter, loc, 1);
     Value c2 = constantIndex(rewriter, loc, 2);
-    Value posBack = c1; // index to the last value in the postion array
-    Value memSize = c2; // memory size for current array
+    Value posBack = c0; // index to the last value in the postion array
+    Value memSize = c1; // memory size for current array
 
     Level trailCOOStart = getCOOStart(stt.getEncoding());
     Level trailCOORank = stt.getLvlRank() - trailCOOStart;
@@ -1266,7 +1267,7 @@ struct SparsePackOpConverter : public OpConversionPattern<PackOp> {
       DimLevelType dlt = stt.getLvlType(lvl);
       // Simply forwards the position index when this is a dense level.
       if (isDenseDLT(dlt)) {
-        memSize = rewriter.create<arith::MulIOp>(loc, lvlSize, posBack);
+        memSize = rewriter.create<arith::MulIOp>(loc, lvlSize, memSize);
         posBack = rewriter.create<arith::SubIOp>(loc, memSize, c1);
         continue;
       }
@@ -1276,6 +1277,10 @@ struct SparsePackOpConverter : public OpConversionPattern<PackOp> {
       if (isCompressedWithHiDLT(dlt)) {
         memSize = rewriter.create<arith::MulIOp>(loc, memSize, c2);
         posBack = rewriter.create<arith::SubIOp>(loc, memSize, c1);
+      } else {
+        assert(isCompressedDLT(dlt));
+        posBack = memSize;
+        memSize = rewriter.create<arith::AddIOp>(loc, memSize, c1);
       }
       desc.setPosMemSize(rewriter, loc, lvl, memSize);
       // The last value in position array is the memory size for next level.
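In brief, the converter walks the storage levels tracking two values: memSize, the memory size of the array at the current level, and posBack, the index of the last entry in that level's position array. Before this change the walk started as if the top level were compressed (posBack = 1, memSize = 2), and dense levels scaled posBack instead of memSize; the fix starts from a single implicit parent entry (posBack = 0, memSize = 1) and handles plain compressed levels explicitly. Below is a minimal plain-C++ sketch of the fixed recurrence; the names (LevelKind, walkPosSizes) are hypothetical stand-ins, not MLIR API, and in the real lowering each value is an SSA Value and memSize is reloaded from pos[posBack] in the packed buffer after each compressed level.

    // Hypothetical stand-ins; not the MLIR implementation.
    #include <cassert>
    #include <cstddef>
    #include <cstdint>
    #include <vector>

    enum class LevelKind { Dense, Compressed, CompressedHi };

    // Mirrors the fixed loop: start from one implicit parent entry
    // (memSize = 1, posBack = 0) rather than assuming a compressed top level.
    void walkPosSizes(const std::vector<LevelKind> &kinds,
                      const std::vector<uint64_t> &lvlSizes) {
      uint64_t memSize = 1; // memory size for the current array
      uint64_t posBack = 0; // index to the last value in the position array
      for (std::size_t l = 0; l < kinds.size(); ++l) {
        if (kinds[l] == LevelKind::Dense) {
          // The fix: scale the running memory size, not posBack.
          memSize = lvlSizes[l] * memSize;
          posBack = memSize - 1;
          continue;
        }
        if (kinds[l] == LevelKind::CompressedHi) {
          memSize = memSize * 2; // a lo/hi position pair per parent entry
          posBack = memSize - 1;
        } else {
          assert(kinds[l] == LevelKind::Compressed);
          posBack = memSize; // pos array holds memSize + 1 entries
          memSize = memSize + 1;
        }
        // The real code records memSize as the level's pos_mem_sz, then loads
        // pos[posBack]: the last position value is the memory size (number of
        // stored entries) for the next level.
      }
    }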
--- a/mlir/test/Dialect/SparseTensor/sparse_pack.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_pack.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --canonicalize --post-sparsification-rewrite="enable-runtime-library=false" --sparse-tensor-codegen -cse | FileCheck %s
+// RUN: mlir-opt %s --canonicalize --post-sparsification-rewrite="enable-runtime-library=false" --sparse-tensor-codegen -cse --canonicalize | FileCheck %s
 
 #COO = #sparse_tensor.encoding<{
   lvlTypes = ["compressed-nu", "singleton"],
@@ -9,25 +9,25 @@
 // CHECK-SAME: %[[VAL_0:.*]]: tensor<6xf64>,
 // CHECK-SAME: %[[VAL_1:.*]]: tensor<2xindex>,
 // CHECK-SAME: %[[VAL_2:.*]]: tensor<6x2xi32>)
-// CHECK-DAG: %[[VAL_3:.*]] = bufferization.to_memref %[[VAL_1]] : memref<2xindex>
-// CHECK-DAG: %[[VAL_4:.*]] = memref.cast %[[VAL_3]] : memref<2xindex> to memref<?xindex>
-// CHECK-DAG: %[[VAL_5:.*]] = bufferization.to_memref %[[VAL_2]] : memref<6x2xi32>
-// CHECK-DAG: %[[VAL_6:.*]] = memref.collapse_shape %[[VAL_5]] {{\[\[}}0, 1]] : memref<6x2xi32> into memref<12xi32>
-// CHECK-DAG: %[[VAL_7:.*]] = memref.cast %[[VAL_6]] : memref<12xi32> to memref<?xi32>
-// CHECK-DAG: %[[VAL_8:.*]] = bufferization.to_memref %[[VAL_0]] : memref<6xf64>
-// CHECK-DAG: %[[VAL_9:.*]] = memref.cast %[[VAL_8]] : memref<6xf64> to memref<?xf64>
-// CHECK-DAG: %[[VAL_10:.*]] = sparse_tensor.storage_specifier.init
-// CHECK-DAG: %[[VAL_11:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[VAL_12:.*]] = arith.constant 2 : index
-// CHECK-DAG: %[[VAL_13:.*]] = arith.constant 100 : index
-// CHECK: %[[VAL_14:.*]] = sparse_tensor.storage_specifier.set %[[VAL_10]] lvl_sz at 0 with %[[VAL_13]]
-// CHECK: %[[VAL_15:.*]] = sparse_tensor.storage_specifier.set %[[VAL_14]] pos_mem_sz at 0 with %[[VAL_12]]
-// CHECK: %[[VAL_16:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_11]]] : memref<?xindex>
-// CHECK: %[[VAL_17:.*]] = arith.muli %[[VAL_16]], %[[VAL_12]] : index
+// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 100 : index
+// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[VAL_6:.*]] = bufferization.to_memref %[[VAL_1]] : memref<2xindex>
+// CHECK-DAG: %[[VAL_7:.*]] = memref.cast %[[VAL_6]] : memref<2xindex> to memref<?xindex>
+// CHECK-DAG: %[[VAL_8:.*]] = bufferization.to_memref %[[VAL_2]] : memref<6x2xi32>
+// CHECK-DAG: %[[VAL_9:.*]] = memref.collapse_shape %[[VAL_8]] {{\[\[}}0, 1]] : memref<6x2xi32> into memref<12xi32>
+// CHECK-DAG: %[[VAL_10:.*]] = memref.cast %[[VAL_9]] : memref<12xi32> to memref<?xi32>
+// CHECK-DAG: %[[VAL_11:.*]] = bufferization.to_memref %[[VAL_0]] : memref<6xf64>
+// CHECK-DAG: %[[VAL_12:.*]] = memref.cast %[[VAL_11]] : memref<6xf64> to memref<?xf64>
+// CHECK: %[[VAL_13:.*]] = sparse_tensor.storage_specifier.init
+// CHECK: %[[VAL_14:.*]] = sparse_tensor.storage_specifier.set %[[VAL_13]] lvl_sz at 0 with %[[VAL_4]]
+// CHECK: %[[VAL_15:.*]] = sparse_tensor.storage_specifier.set %[[VAL_14]] pos_mem_sz at 0 with %[[VAL_3]]
+// CHECK: %[[VAL_16:.*]] = tensor.extract %[[VAL_1]]{{\[}}%[[VAL_5]]] : tensor<2xindex>
+// CHECK: %[[VAL_17:.*]] = arith.muli %[[VAL_16]], %[[VAL_3]] : index
 // CHECK: %[[VAL_18:.*]] = sparse_tensor.storage_specifier.set %[[VAL_15]] crd_mem_sz at 0 with %[[VAL_17]]
-// CHECK: %[[VAL_19:.*]] = sparse_tensor.storage_specifier.set %[[VAL_18]] lvl_sz at 1 with %[[VAL_13]]
+// CHECK: %[[VAL_19:.*]] = sparse_tensor.storage_specifier.set %[[VAL_18]] lvl_sz at 1 with %[[VAL_4]]
 // CHECK: %[[VAL_20:.*]] = sparse_tensor.storage_specifier.set %[[VAL_19]] val_mem_sz with %[[VAL_16]]
-// CHECK: return %[[VAL_4]], %[[VAL_7]], %[[VAL_9]], %[[VAL_20]]
+// CHECK: return %[[VAL_7]], %[[VAL_10]], %[[VAL_12]], %[[VAL_20]]
 // CHECK: }
 func.func @sparse_pack(%values: tensor<6xf64>, %pos:tensor<2xindex>, %coordinates: tensor<6x2xi32>)
     -> tensor<100x100xf64, #COO> {
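As a cross-check on the updated CHECK lines: for this 100x100 COO tensor with a 2-level trailing COO region, pos_mem_sz at level 0 is the constant 2, crd_mem_sz is 2 * nse (one shared coordinate buffer, two coordinates per stored entry), and val_mem_sz is nse itself, where nse is read out of pos[1]. A tiny hypothetical C++ sanity check; the {0, 6} position buffer is an assumed instance matching the tensor<6xf64> values argument:

    #include <cstdio>

    // Hypothetical sanity check of the specifier values in the CHECK lines.
    int main() {
      unsigned pos[2] = {0, 6};    // assumed contents of the tensor<2xindex> arg
      unsigned nse = pos[1];       // loaded as %[[VAL_16]] in the test
      unsigned posMemSz = 2;       // pos_mem_sz at 0 (constant %[[VAL_3]])
      unsigned crdMemSz = 2 * nse; // crd_mem_sz at 0 (%[[VAL_17]] = arith.muli)
      unsigned valMemSz = nse;     // val_mem_sz
      std::printf("pos_mem_sz=%u crd_mem_sz=%u val_mem_sz=%u\n",
                  posMemSz, crdMemSz, valMemSz);
      return 0;
    }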
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir
@@ -81,6 +81,10 @@ module {
     %s5= sparse_tensor.pack %data, %pos32, %index32 : tensor<3xf64>, tensor<2xi32>, tensor<3x2xi32>
                                          to tensor<10x10xf64, #SortedCOOI32>
 
+    %csr_data = arith.constant dense<
+       [ 1.0, 2.0, 3.0, 4.0]
+    > : tensor<4xf64>
+
     %csr_pos32 = arith.constant dense<
        [0, 1, 3]
     > : tensor<3xi32>
@@ -88,7 +92,7 @@ module {
     %csr_index32 = arith.constant dense<
        [1, 0, 1]
     > : tensor<3xi32>
-    %csr= sparse_tensor.pack %data, %csr_pos32, %csr_index32 : tensor<3xf64>, tensor<3xi32>, tensor<3xi32>
+    %csr= sparse_tensor.pack %csr_data, %csr_pos32, %csr_index32 : tensor<4xf64>, tensor<3xi32>, tensor<3xi32>
                                          to tensor<2x2xf64, #CSR>
 
     %bdata = arith.constant dense<
@@ -164,6 +168,16 @@ module {
       vector.print %v: f64
     }
 
+    %d_csr = tensor.empty() : tensor<4xf64>
+    %p_csr = tensor.empty() : tensor<3xi32>
+    %i_csr = tensor.empty() : tensor<3xi32>
+    %rd_csr, %rp_csr, %ri_csr = sparse_tensor.unpack %csr : tensor<2x2xf64, #CSR>
+                    outs(%d_csr, %p_csr, %i_csr : tensor<4xf64>, tensor<3xi32>, tensor<3xi32>)
+                    -> tensor<4xf64>, tensor<3xi32>, tensor<3xi32>
+
+    // CHECK-NEXT: ( 1, 2, 3, {{.*}} )
+    %vd_csr = vector.transfer_read %rd_csr[%c0], %f0 : tensor<4xf64>, vector<4xf64>
+    vector.print %vd_csr : vector<4xf64>
+
     // CHECK-NEXT:1
     // CHECK-NEXT:2
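The new CSR round trip packs a 4-element value buffer while pos = [0, 1, 3] declares only three stored entries, so the value memory size (3) is smaller than the buffer capacity (4) and unpack guarantees only the first three values; that is why the CHECK-NEXT line above leaves the fourth vector element as {{.*}}. A hypothetical plain-C++ walk of the same 2x2 CSR instance, for illustration only:

    #include <cstdio>

    // Hypothetical walk of the CSR instance packed above:
    // pos = [0, 1, 3], crd = [1, 0, 1], data = [1.0, 2.0, 3.0, 4.0].
    int main() {
      unsigned pos[3] = {0, 1, 3};
      unsigned crd[3] = {1, 0, 1};
      double val[4] = {1.0, 2.0, 3.0, 4.0}; // capacity 4, only 3 entries stored
      for (unsigned i = 0; i < 2; ++i)      // rows
        for (unsigned k = pos[i]; k < pos[i + 1]; ++k)
          std::printf("(%u, %u) = %g\n", i, crd[k], val[k]);
      // Prints (0,1)=1, (1,0)=2, (1,1)=3; val[3] is never read, matching the
      // {{.*}} wildcard in the CHECK-NEXT line.
      return 0;
    }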