mirror of
https://github.com/intel/llvm.git
synced 2026-01-27 06:06:34 +08:00
[mlir][sparse] support 'batch' dimensions in sparse_tensor.print (#91411)
This commit is contained in:
@@ -417,11 +417,17 @@ static void genEndInsert(OpBuilder &builder, Location loc,
|
||||
/// Generates a subview into the sizes.
|
||||
static Value genSliceToSize(OpBuilder &builder, Location loc, Value mem,
|
||||
Value sz) {
|
||||
auto elemTp = llvm::cast<MemRefType>(mem.getType()).getElementType();
|
||||
auto memTp = llvm::cast<MemRefType>(mem.getType());
|
||||
// For higher-dimensional memrefs, we assume that the innermost
|
||||
// dimension is always of the right size.
|
||||
// TODO: generate complex truncating view here too?
|
||||
if (memTp.getRank() > 1)
|
||||
return mem;
|
||||
// Truncate linear memrefs to given size.
|
||||
return builder
|
||||
.create<memref::SubViewOp>(
|
||||
loc, MemRefType::get({ShapedType::kDynamic}, elemTp), mem,
|
||||
ValueRange{}, ValueRange{sz}, ValueRange{},
|
||||
loc, MemRefType::get({ShapedType::kDynamic}, memTp.getElementType()),
|
||||
mem, ValueRange{}, ValueRange{sz}, ValueRange{},
|
||||
ArrayRef<int64_t>{0}, // static offset
|
||||
ArrayRef<int64_t>{ShapedType::kDynamic}, // dynamic size
|
||||
ArrayRef<int64_t>{1}) // static stride
|
||||
|
||||
@@ -785,45 +785,61 @@ public:
|
||||
}
|
||||
|
||||
private:
|
||||
// Helper to print contents of a single memref. Note that for the "push_back"
|
||||
// vectors, this prints the full capacity, not just the size. This is done
|
||||
// on purpose, so that clients see how much storage has been allocated in
|
||||
// total. Contents of the extra capacity in the buffer may be uninitialized
|
||||
// (unless the flag enable-buffer-initialization is set to true).
|
||||
// Helper to print contents of a single memref. For "push_back" vectors,
|
||||
// we assume that the previous getters for pos/crd/val have added a
|
||||
// slice-to-size view to make sure we just print the size and not the
|
||||
// full capacity.
|
||||
//
|
||||
// Generates code to print:
|
||||
// Generates code to print (1-dim or higher):
|
||||
// ( a0, a1, ... )
|
||||
static void printContents(PatternRewriter &rewriter, Location loc,
|
||||
Value vec) {
|
||||
auto shape = cast<ShapedType>(vec.getType()).getShape();
|
||||
SmallVector<Value> idxs;
|
||||
printContentsLevel(rewriter, loc, vec, 0, shape, idxs);
|
||||
rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::NewLine);
|
||||
}
|
||||
|
||||
// Helper to the helper.
|
||||
static void printContentsLevel(PatternRewriter &rewriter, Location loc,
|
||||
Value vec, unsigned i, ArrayRef<int64_t> shape,
|
||||
SmallVectorImpl<Value> &idxs) {
|
||||
// Open bracket.
|
||||
rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);
|
||||
// For loop over elements.
|
||||
// Generate for loop.
|
||||
auto zero = constantIndex(rewriter, loc, 0);
|
||||
auto size = rewriter.create<memref::DimOp>(loc, vec, zero);
|
||||
auto index = constantIndex(rewriter, loc, i);
|
||||
auto size = rewriter.create<memref::DimOp>(loc, vec, index);
|
||||
auto step = constantIndex(rewriter, loc, 1);
|
||||
auto forOp = rewriter.create<scf::ForOp>(loc, zero, size, step);
|
||||
idxs.push_back(forOp.getInductionVar());
|
||||
rewriter.setInsertionPointToStart(forOp.getBody());
|
||||
auto idx = forOp.getInductionVar();
|
||||
auto val = rewriter.create<memref::LoadOp>(loc, vec, idx);
|
||||
if (llvm::isa<ComplexType>(val.getType())) {
|
||||
// Since the vector dialect does not support complex types in any op,
|
||||
// we split those into (real, imag) pairs here.
|
||||
Value real = rewriter.create<complex::ReOp>(loc, val);
|
||||
Value imag = rewriter.create<complex::ImOp>(loc, val);
|
||||
rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);
|
||||
rewriter.create<vector::PrintOp>(loc, real,
|
||||
vector::PrintPunctuation::Comma);
|
||||
rewriter.create<vector::PrintOp>(loc, imag,
|
||||
vector::PrintPunctuation::Close);
|
||||
rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Comma);
|
||||
if (i < shape.size() - 1) {
|
||||
// Enter deeper loop nest.
|
||||
printContentsLevel(rewriter, loc, vec, i + 1, shape, idxs);
|
||||
} else {
|
||||
rewriter.create<vector::PrintOp>(loc, val,
|
||||
vector::PrintPunctuation::Comma);
|
||||
// Actual contents printing.
|
||||
auto val = rewriter.create<memref::LoadOp>(loc, vec, idxs);
|
||||
if (llvm::isa<ComplexType>(val.getType())) {
|
||||
// Since the vector dialect does not support complex types in any op,
|
||||
// we split those into (real, imag) pairs here.
|
||||
Value real = rewriter.create<complex::ReOp>(loc, val);
|
||||
Value imag = rewriter.create<complex::ImOp>(loc, val);
|
||||
rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);
|
||||
rewriter.create<vector::PrintOp>(loc, real,
|
||||
vector::PrintPunctuation::Comma);
|
||||
rewriter.create<vector::PrintOp>(loc, imag,
|
||||
vector::PrintPunctuation::Close);
|
||||
rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Comma);
|
||||
} else {
|
||||
rewriter.create<vector::PrintOp>(loc, val,
|
||||
vector::PrintPunctuation::Comma);
|
||||
}
|
||||
}
|
||||
idxs.pop_back();
|
||||
rewriter.setInsertionPointAfter(forOp);
|
||||
// Close bracket and end of line.
|
||||
// Close bracket.
|
||||
rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Close);
|
||||
rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::NewLine);
|
||||
}
|
||||
|
||||
// Helper method to print run-time lvl/dim sizes.
|
||||
|
||||
@@ -29,7 +29,7 @@
|
||||
crdWidth = 32
|
||||
}>
|
||||
|
||||
#BatchedCSR = #sparse_tensor.encoding<{
|
||||
#DenseCSR = #sparse_tensor.encoding<{
|
||||
map = (d0, d1, d2) -> (d0 : dense, d1 : dense, d2 : compressed),
|
||||
posWidth = 64,
|
||||
crdWidth = 32
|
||||
@@ -42,7 +42,7 @@
|
||||
}>
|
||||
|
||||
//
|
||||
// Test assembly operation with CCC, batched-CSR and CSR-dense.
|
||||
// Test assembly operation with CCC, dense-CSR and CSR-dense.
|
||||
//
|
||||
module {
|
||||
//
|
||||
@@ -77,7 +77,7 @@ module {
|
||||
tensor<6xi64>, tensor<8xi32>), tensor<8xf32> to tensor<4x3x2xf32, #CCC>
|
||||
|
||||
//
|
||||
// Setup BatchedCSR.
|
||||
// Setup DenseCSR.
|
||||
//
|
||||
|
||||
%data1 = arith.constant dense<
|
||||
@@ -88,7 +88,7 @@ module {
|
||||
%crd1 = arith.constant dense<
|
||||
[ 0, 1, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1]> : tensor<16xi32>
|
||||
|
||||
%s1 = sparse_tensor.assemble (%pos1, %crd1), %data1 : (tensor<13xi64>, tensor<16xi32>), tensor<16xf32> to tensor<4x3x2xf32, #BatchedCSR>
|
||||
%s1 = sparse_tensor.assemble (%pos1, %crd1), %data1 : (tensor<13xi64>, tensor<16xi32>), tensor<16xf32> to tensor<4x3x2xf32, #DenseCSR>
|
||||
|
||||
//
|
||||
// Setup CSRDense.
|
||||
@@ -137,7 +137,7 @@ module {
|
||||
// CHECK-NEXT: ----
|
||||
//
|
||||
sparse_tensor.print %s0 : tensor<4x3x2xf32, #CCC>
|
||||
sparse_tensor.print %s1 : tensor<4x3x2xf32, #BatchedCSR>
|
||||
sparse_tensor.print %s1 : tensor<4x3x2xf32, #DenseCSR>
|
||||
sparse_tensor.print %s2 : tensor<4x3x2xf32, #CSRDense>
|
||||
|
||||
// TODO: This check is no longer needed once the codegen path uses the
|
||||
@@ -148,7 +148,7 @@ module {
|
||||
// sparse_tensor.assemble copies buffers when running with the runtime
|
||||
// library. Deallocations are not needed when running in codegen mode.
|
||||
bufferization.dealloc_tensor %s0 : tensor<4x3x2xf32, #CCC>
|
||||
bufferization.dealloc_tensor %s1 : tensor<4x3x2xf32, #BatchedCSR>
|
||||
bufferization.dealloc_tensor %s1 : tensor<4x3x2xf32, #DenseCSR>
|
||||
bufferization.dealloc_tensor %s2 : tensor<4x3x2xf32, #CSRDense>
|
||||
}
|
||||
|
||||
|
||||
74
mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_print_3d.mlir
Executable file
74
mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_print_3d.mlir
Executable file
@@ -0,0 +1,74 @@
|
||||
//--------------------------------------------------------------------------------------------------
|
||||
// WHEN CREATING A NEW TEST, PLEASE JUST COPY & PASTE WITHOUT EDITS.
|
||||
//
|
||||
// Set-up that's shared across all tests in this directory. In principle, this
|
||||
// config could be moved to lit.local.cfg. However, there are downstream users that
|
||||
// do not use these LIT config files. Hence why this is kept inline.
|
||||
//
|
||||
// DEFINE: %{sparsifier_opts} = enable-runtime-library=true
|
||||
// DEFINE: %{sparsifier_opts_sve} = enable-arm-sve=true %{sparsifier_opts}
|
||||
// DEFINE: %{compile} = mlir-opt %s --sparsifier="%{sparsifier_opts}"
|
||||
// DEFINE: %{compile_sve} = mlir-opt %s --sparsifier="%{sparsifier_opts_sve}"
|
||||
// DEFINE: %{run_libs} = -shared-libs=%mlir_c_runner_utils,%mlir_runner_utils
|
||||
// DEFINE: %{run_opts} = -e main -entry-point-result=void
|
||||
// DEFINE: %{run} = mlir-cpu-runner %{run_opts} %{run_libs}
|
||||
// DEFINE: %{run_sve} = %mcr_aarch64_cmd --march=aarch64 --mattr="+sve" %{run_opts} %{run_libs}
|
||||
//
|
||||
// DEFINE: %{env} =
|
||||
//--------------------------------------------------------------------------------------------------
|
||||
|
||||
// TODO: make this work with libgen
|
||||
|
||||
// Do the same run, but now with direct IR generation.
|
||||
// REDEFINE: %{sparsifier_opts} = enable-runtime-library=false enable-buffer-initialization=true
|
||||
// RUN: %{compile} | %{run} | FileCheck %s
|
||||
//
|
||||
|
||||
#BatchedCSR = #sparse_tensor.encoding<{
|
||||
map = (d0, d1, d2) -> (d0 : batch, d1 : dense, d2 : compressed)
|
||||
}>
|
||||
|
||||
module {
|
||||
|
||||
//
|
||||
// Main driver that tests 3-D sparse tensor printing.
|
||||
//
|
||||
func.func @main() {
|
||||
|
||||
%pos = arith.constant dense<
|
||||
[[ 0, 8, 16, 24, 32],
|
||||
[ 0, 8, 16, 24, 32]]
|
||||
> : tensor<2x5xindex>
|
||||
|
||||
%crd = arith.constant dense<
|
||||
[[0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7],
|
||||
[0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7]]
|
||||
> : tensor<2x32xindex>
|
||||
|
||||
%val = arith.constant dense<
|
||||
[[ 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11.,
|
||||
12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22.,
|
||||
23., 24., 25., 26., 27., 28., 29., 30., 31., 32.],
|
||||
[33., 34., 35., 36., 37., 38., 39., 40., 41., 42., 43.,
|
||||
44., 45., 46., 47., 48., 49., 50., 51., 52., 53., 54.,
|
||||
55., 56., 57., 58., 59., 60., 61., 62., 63., 64.]]
|
||||
> : tensor<2x32xf64>
|
||||
|
||||
%X = sparse_tensor.assemble (%pos, %crd), %val
|
||||
: (tensor<2x5xindex>, tensor<2x32xindex>), tensor<2x32xf64> to tensor<2x4x8xf64, #BatchedCSR>
|
||||
|
||||
// CHECK: ---- Sparse Tensor ----
|
||||
// CHECK-NEXT: nse = 32
|
||||
// CHECK-NEXT: dim = ( 2, 4, 8 )
|
||||
// CHECK-NEXT: lvl = ( 2, 4, 8 )
|
||||
// CHECK-NEXT: pos[2] : ( ( 0, 8, 16, 24, 32, )( 0, 8, 16, 24, 32, ) )
|
||||
// CHECK-NEXT: crd[2] : ( ( 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, )
|
||||
// CHECK-SAME: ( 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, ) )
|
||||
// CHECK-NEXT: values : ( ( 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, )
|
||||
// CHECK-SAME: ( 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, ) )
|
||||
// CHECK-NEXT: ----
|
||||
sparse_tensor.print %X : tensor<2x4x8xf64, #BatchedCSR>
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user