From c4e5a8a4d3ef0948384d9411ea1e44fc113e5b5c Mon Sep 17 00:00:00 2001
From: Aart Bik
Date: Tue, 7 May 2024 19:01:36 -0700
Subject: [PATCH] [mlir][sparse] support 'batch' dimensions in
 sparse_tensor.print (#91411)

---
 .../Transforms/SparseTensorCodegen.cpp        | 12 ++-
 .../Transforms/SparseTensorRewriting.cpp      | 66 ++++++++++-------
 .../SparseTensor/CPU/sparse_pack_d.mlir       | 12 +--
 .../SparseTensor/CPU/sparse_print_3d.mlir     | 74 +++++++++++++++++++
 4 files changed, 130 insertions(+), 34 deletions(-)
 create mode 100755 mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_print_3d.mlir

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index d9b203a88648..164e722c45db 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -417,11 +417,17 @@ static void genEndInsert(OpBuilder &builder, Location loc,
 /// Generates a subview into the sizes.
 static Value genSliceToSize(OpBuilder &builder, Location loc, Value mem,
                             Value sz) {
-  auto elemTp = llvm::cast<MemRefType>(mem.getType()).getElementType();
+  auto memTp = llvm::cast<MemRefType>(mem.getType());
+  // For higher-dimensional memrefs, we assume that the innermost
+  // dimension is always of the right size.
+  // TODO: generate complex truncating view here too?
+  if (memTp.getRank() > 1)
+    return mem;
+  // Truncate linear memrefs to given size.
   return builder
       .create<memref::SubViewOp>(
-          loc, MemRefType::get({ShapedType::kDynamic}, elemTp), mem,
-          ValueRange{}, ValueRange{sz}, ValueRange{},
+          loc, MemRefType::get({ShapedType::kDynamic}, memTp.getElementType()),
+          mem, ValueRange{}, ValueRange{sz}, ValueRange{},
           ArrayRef<int64_t>{0},                    // static offset
           ArrayRef<int64_t>{ShapedType::kDynamic}, // dynamic size
           ArrayRef<int64_t>{1})                    // static stride
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 7d469198a653..025fd3331ba8 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -785,45 +785,61 @@ public:
   }
 
 private:
-  // Helper to print contents of a single memref. Note that for the "push_back"
-  // vectors, this prints the full capacity, not just the size. This is done
-  // on purpose, so that clients see how much storage has been allocated in
-  // total. Contents of the extra capacity in the buffer may be uninitialized
-  // (unless the flag enable-buffer-initialization is set to true).
+  // Helper to print contents of a single memref. For "push_back" vectors,
+  // we assume that the previous getters for pos/crd/val have added a
+  // slice-to-size view to make sure we just print the size and not the
+  // full capacity.
   //
-  // Generates code to print:
+  // Generates code to print (1-dim or higher):
   //    ( a0, a1, ... )
   static void printContents(PatternRewriter &rewriter, Location loc,
                             Value vec) {
+    auto shape = cast<ShapedType>(vec.getType()).getShape();
+    SmallVector<Value> idxs;
+    printContentsLevel(rewriter, loc, vec, 0, shape, idxs);
+    rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::NewLine);
+  }
+
+  // Helper to the helper.
+  static void printContentsLevel(PatternRewriter &rewriter, Location loc,
+                                 Value vec, unsigned i, ArrayRef<int64_t> shape,
+                                 SmallVectorImpl<Value> &idxs) {
     // Open bracket.
     rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);
-    // For loop over elements.
+    // Generate for loop.
     auto zero = constantIndex(rewriter, loc, 0);
-    auto size = rewriter.create<memref::DimOp>(loc, vec, zero);
+    auto index = constantIndex(rewriter, loc, i);
+    auto size = rewriter.create<memref::DimOp>(loc, vec, index);
     auto step = constantIndex(rewriter, loc, 1);
     auto forOp = rewriter.create<scf::ForOp>(loc, zero, size, step);
+    idxs.push_back(forOp.getInductionVar());
     rewriter.setInsertionPointToStart(forOp.getBody());
-    auto idx = forOp.getInductionVar();
-    auto val = rewriter.create<memref::LoadOp>(loc, vec, idx);
-    if (llvm::isa<ComplexType>(val.getType())) {
-      // Since the vector dialect does not support complex types in any op,
-      // we split those into (real, imag) pairs here.
-      Value real = rewriter.create<complex::ReOp>(loc, val);
-      Value imag = rewriter.create<complex::ImOp>(loc, val);
-      rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);
-      rewriter.create<vector::PrintOp>(loc, real,
-                                       vector::PrintPunctuation::Comma);
-      rewriter.create<vector::PrintOp>(loc, imag,
-                                       vector::PrintPunctuation::Close);
-      rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Comma);
+    if (i < shape.size() - 1) {
+      // Enter deeper loop nest.
+      printContentsLevel(rewriter, loc, vec, i + 1, shape, idxs);
     } else {
-      rewriter.create<vector::PrintOp>(loc, val,
-                                       vector::PrintPunctuation::Comma);
+      // Actual contents printing.
+      auto val = rewriter.create<memref::LoadOp>(loc, vec, idxs);
+      if (llvm::isa<ComplexType>(val.getType())) {
+        // Since the vector dialect does not support complex types in any op,
+        // we split those into (real, imag) pairs here.
+        Value real = rewriter.create<complex::ReOp>(loc, val);
+        Value imag = rewriter.create<complex::ImOp>(loc, val);
+        rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);
+        rewriter.create<vector::PrintOp>(loc, real,
+                                         vector::PrintPunctuation::Comma);
+        rewriter.create<vector::PrintOp>(loc, imag,
+                                         vector::PrintPunctuation::Close);
+        rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Comma);
+      } else {
+        rewriter.create<vector::PrintOp>(loc, val,
+                                         vector::PrintPunctuation::Comma);
+      }
     }
+    idxs.pop_back();
     rewriter.setInsertionPointAfter(forOp);
-    // Close bracket and end of line.
+    // Close bracket.
     rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Close);
-    rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::NewLine);
   }
 
   // Helper method to print run-time lvl/dim sizes.
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack_d.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack_d.mlir
index 20ae7e86285c..467a77f30777 100755
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack_d.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack_d.mlir
@@ -29,7 +29,7 @@
   crdWidth = 32
 }>
 
-#BatchedCSR = #sparse_tensor.encoding<{
+#DenseCSR = #sparse_tensor.encoding<{
   map = (d0, d1, d2) -> (d0 : dense, d1 : dense, d2 : compressed),
   posWidth = 64,
   crdWidth = 32
@@ -42,7 +42,7 @@
 }>
 
 //
-// Test assembly operation with CCC, batched-CSR and CSR-dense.
+// Test assembly operation with CCC, dense-CSR and CSR-dense.
 //
 module {
   //
@@ -77,7 +77,7 @@ module {
                          tensor<6xi64>, tensor<8xi32>), tensor<8xf32> to tensor<4x3x2xf32, #CCC>
 
     //
-    // Setup BatchedCSR.
+    // Setup DenseCSR.
    //
 
     %data1 = arith.constant dense<
@@ -88,7 +88,7 @@ module {
     %crd1 = arith.constant dense<
      [ 0, 1, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1]> : tensor<16xi32>
 
-    %s1 = sparse_tensor.assemble (%pos1, %crd1), %data1 : (tensor<13xi64>, tensor<16xi32>), tensor<16xf32> to tensor<4x3x2xf32, #BatchedCSR>
+    %s1 = sparse_tensor.assemble (%pos1, %crd1), %data1 : (tensor<13xi64>, tensor<16xi32>), tensor<16xf32> to tensor<4x3x2xf32, #DenseCSR>
 
     //
     // Setup CSRDense.
@@ -137,7 +137,7 @@ module {
     // CHECK-NEXT: ----
     //
     sparse_tensor.print %s0 : tensor<4x3x2xf32, #CCC>
-    sparse_tensor.print %s1 : tensor<4x3x2xf32, #BatchedCSR>
+    sparse_tensor.print %s1 : tensor<4x3x2xf32, #DenseCSR>
     sparse_tensor.print %s2 : tensor<4x3x2xf32, #CSRDense>
 
     // TODO: This check is no longer needed once the codegen path uses the
@@ -148,7 +148,7 @@ module {
     // sparse_tensor.assemble copies buffers when running with the runtime
     // library. Deallocations are not needed when running in codegen mode.
     bufferization.dealloc_tensor %s0 : tensor<4x3x2xf32, #CCC>
-    bufferization.dealloc_tensor %s1 : tensor<4x3x2xf32, #BatchedCSR>
+    bufferization.dealloc_tensor %s1 : tensor<4x3x2xf32, #DenseCSR>
     bufferization.dealloc_tensor %s2 : tensor<4x3x2xf32, #CSRDense>
   }
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_print_3d.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_print_3d.mlir
new file mode 100755
index 000000000000..98dee304fa51
--- /dev/null
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_print_3d.mlir
@@ -0,0 +1,74 @@
+//--------------------------------------------------------------------------------------------------
+// WHEN CREATING A NEW TEST, PLEASE JUST COPY & PASTE WITHOUT EDITS.
+//
+// Set-up that's shared across all tests in this directory. In principle, this
+// config could be moved to lit.local.cfg. However, there are downstream users that
+// do not use these LIT config files. Hence why this is kept inline.
+//
+// DEFINE: %{sparsifier_opts} = enable-runtime-library=true
+// DEFINE: %{sparsifier_opts_sve} = enable-arm-sve=true %{sparsifier_opts}
+// DEFINE: %{compile} = mlir-opt %s --sparsifier="%{sparsifier_opts}"
+// DEFINE: %{compile_sve} = mlir-opt %s --sparsifier="%{sparsifier_opts_sve}"
+// DEFINE: %{run_libs} = -shared-libs=%mlir_c_runner_utils,%mlir_runner_utils
+// DEFINE: %{run_opts} = -e main -entry-point-result=void
+// DEFINE: %{run} = mlir-cpu-runner %{run_opts} %{run_libs}
+// DEFINE: %{run_sve} = %mcr_aarch64_cmd --march=aarch64 --mattr="+sve" %{run_opts} %{run_libs}
+//
+// DEFINE: %{env} =
+//--------------------------------------------------------------------------------------------------
+
+// TODO: make this work with libgen
+
+// Do the same run, but now with direct IR generation.
+// REDEFINE: %{sparsifier_opts} = enable-runtime-library=false enable-buffer-initialization=true
+// RUN: %{compile} | %{run} | FileCheck %s
+//
+
+#BatchedCSR = #sparse_tensor.encoding<{
+  map = (d0, d1, d2) -> (d0 : batch, d1 : dense, d2 : compressed)
+}>
+
+module {
+
+  //
+  // Main driver that tests 3-D sparse tensor printing.
+ // + func.func @main() { + + %pos = arith.constant dense< + [[ 0, 8, 16, 24, 32], + [ 0, 8, 16, 24, 32]] + > : tensor<2x5xindex> + + %crd = arith.constant dense< + [[0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7], + [0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7]] + > : tensor<2x32xindex> + + %val = arith.constant dense< + [[ 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., + 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., + 23., 24., 25., 26., 27., 28., 29., 30., 31., 32.], + [33., 34., 35., 36., 37., 38., 39., 40., 41., 42., 43., + 44., 45., 46., 47., 48., 49., 50., 51., 52., 53., 54., + 55., 56., 57., 58., 59., 60., 61., 62., 63., 64.]] + > : tensor<2x32xf64> + + %X = sparse_tensor.assemble (%pos, %crd), %val + : (tensor<2x5xindex>, tensor<2x32xindex>), tensor<2x32xf64> to tensor<2x4x8xf64, #BatchedCSR> + + // CHECK: ---- Sparse Tensor ---- + // CHECK-NEXT: nse = 32 + // CHECK-NEXT: dim = ( 2, 4, 8 ) + // CHECK-NEXT: lvl = ( 2, 4, 8 ) + // CHECK-NEXT: pos[2] : ( ( 0, 8, 16, 24, 32, )( 0, 8, 16, 24, 32, ) ) + // CHECK-NEXT: crd[2] : ( ( 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, ) + // CHECK-SAME: ( 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, ) ) + // CHECK-NEXT: values : ( ( 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, ) + // CHECK-SAME: ( 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, ) ) + // CHECK-NEXT: ---- + sparse_tensor.print %X : tensor<2x4x8xf64, #BatchedCSR> + + return + } +}
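
Reviewer note (not part of the patch): for contrast with the batched case
above, a plain 2-D CSR tensor keeps one-dimensional pos/crd/val buffers, so
the recursive printer bottoms out immediately and emits one flat list per
buffer instead of the nested ( ( ... )( ... ) ) form. A minimal sketch under
that assumption; the encoding and function names are illustrative, and the
expected output lines below are hand-derived from this particular input:

#CSR = #sparse_tensor.encoding<{
  map = (d0, d1) -> (d0 : dense, d1 : compressed)
}>

module {
  func.func @print_csr() {
    // Two rows with three stored nonzeros: (0,0)=1, (0,2)=2, (1,2)=3.
    %t = arith.constant dense<
      [[ 1., 0., 2.],
       [ 0., 0., 3.]]> : tensor<2x3xf64>
    %s = sparse_tensor.convert %t : tensor<2x3xf64> to tensor<2x3xf64, #CSR>
    // Expected shape of the output (values hand-derived, illustrative):
    //   ---- Sparse Tensor ----
    //   nse = 3
    //   dim = ( 2, 3 )
    //   lvl = ( 2, 3 )
    //   pos[1] : ( 0, 2, 3, )
    //   crd[1] : ( 0, 2, 2, )
    //   values : ( 1, 2, 3, )
    //   ----
    sparse_tensor.print %s : tensor<2x3xf64, #CSR>
    bufferization.dealloc_tensor %s : tensor<2x3xf64, #CSR>
    return
  }
}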