[mlir][sparse] rename lex_insert to insert

This change does not impact any semantics yet, but it is in
preparation for implementing the unordered and not-unique
properties. Renaming lex_insert to insert is a first step.

Reviewed By: Peiming

Differential Revision: https://reviews.llvm.org/D133531
Aart Bik
2022-09-08 14:41:18 -07:00
parent 579a5a47a9
commit f76dcede3f
12 changed files with 35 additions and 26 deletions


@@ -49,9 +49,9 @@ def SparseTensor_Dialect : Dialect {
lattices drive actual sparse code generation, which consists of a
relatively straightforward one-to-one mapping from iteration lattices
to combinations of for-loops, while-loops, and if-statements. Sparse
- tensor outputs that materialize uninitialized are handled with
- insertions in pure lexicographical index order if all parallel loops
- are outermost or using a 1-dimensional access pattern expansion
+ tensor outputs that materialize uninitialized are handled with direct
+ insertions if all parallel loops are outermost or insertions that
+ indirectly go through a 1-dimensional access pattern expansion
(a.k.a. workspace) where feasible [Gustavson72,Bik96,Kjolstad19].
* [Bik96] Aart J.C. Bik. Compiler Support for Sparse Matrix Computations.
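The two insertion paths described above surface as ops in the dialect: direct insertion through `sparse_tensor.insert` (renamed in this change), and the workspace path through `sparse_tensor.expand`/`sparse_tensor.compress` (both appear in the conversion patterns further down). A schematic sketch of the workspace path follows; the `#CSR` encoding, function name, and the elided scatter logic are illustrative assumptions, not part of this change:

```mlir
#CSR = #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ] }>

// Schematic use of the workspace: expand into dense buffers, scatter
// into them, then compress back into the sparse output tensor.
func.func @workspace(%tensor: tensor<4x4xf64, #CSR>,
                     %indices: memref<?xindex>) {
  %values, %filled, %added, %count = sparse_tensor.expand %tensor
    : tensor<4x4xf64, #CSR> to memref<?xf64>, memref<?xi1>, memref<?xindex>, index
  // ... scatter values, mark %filled, append new indices to %added ...
  sparse_tensor.compress %tensor, %indices, %values, %filled, %added, %count
    : tensor<4x4xf64, #CSR>, memref<?xindex>, memref<?xf64>, memref<?xi1>, memref<?xindex>, index
  return
}
```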


@@ -197,18 +197,25 @@ def SparseTensor_ConcatenateOp : SparseTensor_Op<"concatenate", []>,
// as our sparse abstractions evolve.
//===----------------------------------------------------------------------===//
- def SparseTensor_LexInsertOp : SparseTensor_Op<"lex_insert", []>,
+ def SparseTensor_InsertOp : SparseTensor_Op<"insert", []>,
Arguments<(ins AnySparseTensor:$tensor,
StridedMemRefRankOf<[Index], [1]>:$indices,
AnyType:$value)> {
- string summary = "Inserts a value into given sparse tensor in lexicographical index order";
+ string summary = "Inserts a value into given sparse tensor";
string description = [{
Inserts the given value at the given indices into the underlying
sparse storage format of the given tensor. This operation can only
be applied when a tensor materializes uninitialized
- with a `bufferization.alloc_tensor` operation, the insertions occur in
- strict lexicographical index order, and the final tensor is constructed
- with a `load` operation that has the `hasInserts` attribute set.
+ with a `bufferization.alloc_tensor` operation and the final tensor
+ is constructed with a `load` operation that has the `hasInserts`
+ attribute set.
+
+ Properties in the sparse tensor type fully describe what kind
+ of insertion order is allowed. When all dimensions have "unique"
+ and "ordered" properties, for example, insertions should occur in
+ strict lexicographical index order. Other properties define
+ different insertion regimens. Inserting in a way contrary to
+ these properties results in undefined behavior.
Note that this operation is "impure" in the sense that its behavior
is solely defined by side-effects and not SSA values. The semantics
@@ -217,7 +224,7 @@ def SparseTensor_LexInsertOp : SparseTensor_Op<"lex_insert", []>,
Example:
```mlir
- sparse_tensor.lex_insert %tensor, %indices, %val
+ sparse_tensor.insert %tensor, %indices, %val
: tensor<1024x1024xf64, #CSR>, memref<?xindex>, memref<f64>
```
}];
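For context, a complete use of the renamed op under the constraints described above looks like the following minimal sketch; the `#SparseVector` encoding, function name, and buffer arguments are illustrative, not part of the op definition:

```mlir
#SparseVector = #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>

// Materialize uninitialized, insert through the index/value buffers,
// then finalize with a load that carries the hasInserts attribute.
func.func @insert_and_finalize(%indices: memref<?xindex>,
                               %value: memref<f64>) -> tensor<1024xf64, #SparseVector> {
  %t = bufferization.alloc_tensor() : tensor<1024xf64, #SparseVector>
  sparse_tensor.insert %t, %indices, %value
    : tensor<1024xf64, #SparseVector>, memref<?xindex>, memref<f64>
  %r = sparse_tensor.load %t hasInserts : tensor<1024xf64, #SparseVector>
  return %r : tensor<1024xf64, #SparseVector>
}
```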


@@ -1137,13 +1137,15 @@ public:
}
};
- /// Sparse conversion rule for inserting in lexicographic index order.
- class SparseTensorLexInsertConverter : public OpConversionPattern<LexInsertOp> {
+ /// Sparse conversion rule for the insertion operator.
+ class SparseTensorInsertConverter : public OpConversionPattern<InsertOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
- matchAndRewrite(LexInsertOp op, OpAdaptor adaptor,
+ matchAndRewrite(InsertOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
+ // Note that the current regime only allows for strict lexicographic
+ // index order.
Type elemTp = op.getTensor().getType().cast<ShapedType>().getElementType();
SmallString<12> name{"lexInsert", primaryTypeFunctionSuffix(elemTp)};
replaceOpWithFuncCall(rewriter, op, name, {}, adaptor.getOperands(),
@@ -1432,7 +1434,7 @@ void mlir::populateSparseTensorConversionPatterns(
SparseTensorConcatConverter, SparseTensorAllocConverter,
SparseTensorDeallocConverter, SparseTensorToPointersConverter,
SparseTensorToIndicesConverter, SparseTensorToValuesConverter,
- SparseTensorLoadConverter, SparseTensorLexInsertConverter,
+ SparseTensorLoadConverter, SparseTensorInsertConverter,
SparseTensorExpandConverter, SparseTensorCompressConverter,
SparseTensorOutConverter>(typeConverter, patterns.getContext());
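Note that the pattern still emits the `lexInsert` family of library calls (`lexInsert` plus the primary type suffix, e.g. `lexInsertF64`), so only the dialect-level op name changes here. Roughly, and with operand types assumed from the surrounding conversion tests, the converted IR reads:

```mlir
// Hedged sketch of the converted form for an f64 insertion; the
// !llvm.ptr<i8> handle is the runtime's opaque sparse tensor type.
call @lexInsertF64(%tensor, %indices, %value)
    : (!llvm.ptr<i8>, memref<?xindex>, memref<f64>) -> ()
```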


@@ -762,7 +762,7 @@ static void genInsertionStore(CodeGen &codegen, OpBuilder &builder,
// Direct insertion in lexicographic index order.
if (!codegen.expValues) {
builder.create<memref::StoreOp>(loc, rhs, codegen.lexVal);
- builder.create<LexInsertOp>(loc, t->get(), codegen.lexIdx, codegen.lexVal);
+ builder.create<InsertOp>(loc, t->get(), codegen.lexIdx, codegen.lexVal);
return;
}
// Generates insertion code along expanded access pattern.


@@ -494,7 +494,7 @@ func.func @sparse_reconstruct_ins(%arg0: tensor<128xf32, #SparseVector>) -> tens
func.func @sparse_insert(%arg0: tensor<128xf32, #SparseVector>,
%arg1: memref<?xindex>,
%arg2: memref<f32>) {
- sparse_tensor.lex_insert %arg0, %arg1, %arg2 : tensor<128xf32, #SparseVector>, memref<?xindex>, memref<f32>
+ sparse_tensor.insert %arg0, %arg1, %arg2 : tensor<128xf32, #SparseVector>, memref<?xindex>, memref<f32>
return
}


@@ -107,8 +107,8 @@ func.func @sparse_unannotated_load(%arg0: tensor<16x32xf64>) -> tensor<16x32xf64
// -----
func.func @sparse_unannotated_insert(%arg0: tensor<128xf64>, %arg1: memref<?xindex>, %arg2: f64) {
- // expected-error@+1 {{'sparse_tensor.lex_insert' op operand #0 must be sparse tensor of any type values, but got 'tensor<128xf64>'}}
- sparse_tensor.lex_insert %arg0, %arg1, %arg2 : tensor<128xf64>, memref<?xindex>, f64
+ // expected-error@+1 {{'sparse_tensor.insert' op operand #0 must be sparse tensor of any type values, but got 'tensor<128xf64>'}}
+ sparse_tensor.insert %arg0, %arg1, %arg2 : tensor<128xf64>, memref<?xindex>, f64
return
}


@@ -123,10 +123,10 @@ func.func @sparse_load_ins(%arg0: tensor<16x32xf64, #DenseMatrix>) -> tensor<16x
// CHECK-SAME: %[[A:.*]]: tensor<128xf64, #sparse_tensor.encoding<{{.*}}>>,
// CHECK-SAME: %[[B:.*]]: memref<?xindex>,
// CHECK-SAME: %[[C:.*]]: f64) {
- // CHECK: sparse_tensor.lex_insert %[[A]], %[[B]], %[[C]] : tensor<128xf64, #{{.*}}>, memref<?xindex>, f64
+ // CHECK: sparse_tensor.insert %[[A]], %[[B]], %[[C]] : tensor<128xf64, #{{.*}}>, memref<?xindex>, f64
// CHECK: return
func.func @sparse_insert(%arg0: tensor<128xf64, #SparseVector>, %arg1: memref<?xindex>, %arg2: f64) {
- sparse_tensor.lex_insert %arg0, %arg1, %arg2 : tensor<128xf64, #SparseVector>, memref<?xindex>, f64
+ sparse_tensor.insert %arg0, %arg1, %arg2 : tensor<128xf64, #SparseVector>, memref<?xindex>, f64
return
}


@@ -375,7 +375,7 @@ func.func @divbyc(%arga: tensor<32xf64, #SV>,
// CHECK: %[[VAL_20:.*]] = math.sin %[[VAL_19]] : f64
// CHECK: %[[VAL_21:.*]] = math.tanh %[[VAL_20]] : f64
// CHECK: memref.store %[[VAL_21]], %[[BUF]][] : memref<f64>
- // CHECK: sparse_tensor.lex_insert %[[VAL_4]], %[[VAL_8]], %[[BUF]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>>, memref<?xindex>, memref<f64>
+ // CHECK: sparse_tensor.insert %[[VAL_4]], %[[VAL_8]], %[[BUF]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>>, memref<?xindex>, memref<f64>
// CHECK: }
// CHECK: %[[VAL_22:.*]] = sparse_tensor.load %[[VAL_4]] hasInserts : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>>
// CHECK: return %[[VAL_22]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>>
@@ -419,7 +419,7 @@ func.func @zero_preserving_math(%arga: tensor<32xf64, #SV>) -> tensor<32xf64, #S
// CHECK: %[[VAL_14:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_12]]] : memref<?xcomplex<f64>>
// CHECK: %[[VAL_15:.*]] = complex.div %[[VAL_14]], %[[VAL_3]] : complex<f64>
// CHECK: memref.store %[[VAL_15]], %[[VAL_9]][] : memref<complex<f64>>
- // CHECK: sparse_tensor.lex_insert %[[VAL_4]], %[[VAL_8]], %[[VAL_9]] : tensor<32xcomplex<f64>, #sparse_tensor.encoding<{{.*}}>>, memref<?xindex>, memref<complex<f64>>
+ // CHECK: sparse_tensor.insert %[[VAL_4]], %[[VAL_8]], %[[VAL_9]] : tensor<32xcomplex<f64>, #sparse_tensor.encoding<{{.*}}>>, memref<?xindex>, memref<complex<f64>>
// CHECK: }
// CHECK: %[[VAL_16:.*]] = sparse_tensor.load %[[VAL_4]] hasInserts : tensor<32xcomplex<f64>, #sparse_tensor.encoding<{{.*}}>>
// CHECK: return %[[VAL_16]] : tensor<32xcomplex<f64>, #sparse_tensor.encoding<{{.*}}>>


@@ -100,7 +100,7 @@ func.func @dense_index(%arga: tensor<?x?xi64, #DenseMatrix>)
// CHECK: %[[VAL_25:.*]] = arith.muli %[[VAL_23]], %[[VAL_24]] : i64
// CHECK: %[[VAL_26:.*]] = arith.muli %[[VAL_22]], %[[VAL_25]] : i64
// CHECK: memref.store %[[VAL_26]], %[[BUF]][] : memref<i64>
- // CHECK: sparse_tensor.lex_insert %[[VAL_6]], %[[VAL_12]], %[[BUF]] : tensor<?x?xi64, #sparse_tensor.encoding
+ // CHECK: sparse_tensor.insert %[[VAL_6]], %[[VAL_12]], %[[BUF]] : tensor<?x?xi64, #sparse_tensor.encoding
// CHECK: }
// CHECK: }
// CHECK: %[[VAL_27:.*]] = sparse_tensor.load %[[VAL_6]] hasInserts : tensor<?x?xi64, #sparse_tensor.encoding


@@ -123,7 +123,7 @@ func.func @sparse_simply_dynamic2(%argx: tensor<32x16xf32, #DCSR>) -> tensor<32x
// CHECK: %[[VAL_18:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_16]]] : memref<?xf32>
// CHECK: %[[VAL_19:.*]] = arith.mulf %[[VAL_18]], %[[VAL_1]] : f32
// CHECK: memref.store %[[VAL_19]], %[[BUF]][] : memref<f32>
- // CHECK: sparse_tensor.lex_insert %[[VAL_7]], %[[VAL_11]], %[[BUF]] : tensor<10x20xf32, #sparse_tensor.encoding<{{.*}}>>
+ // CHECK: sparse_tensor.insert %[[VAL_7]], %[[VAL_11]], %[[BUF]] : tensor<10x20xf32, #sparse_tensor.encoding<{{.*}}>>
// CHECK: }
// CHECK: }
// CHECK: %[[VAL_20:.*]] = sparse_tensor.load %[[VAL_7]] hasInserts : tensor<10x20xf32, #sparse_tensor.encoding<{{.*}}>>
@@ -259,7 +259,7 @@ func.func @sparse_truly_dynamic(%arga: tensor<10x20xf32, #CSR>) -> tensor<10x20x
// CHECK: scf.yield %[[VAL_94]], %[[VAL_97]], %[[VAL_98:.*]] : index, index, i32
// CHECK: }
// CHECK: memref.store %[[VAL_70]]#2, %[[BUF]][] : memref<i32>
- // CHECK: sparse_tensor.lex_insert %[[VAL_8]], %[[VAL_23]], %[[BUF]] : tensor<?x?xi32, #{{.*}}>, memref<?xindex>, memref<i32>
+ // CHECK: sparse_tensor.insert %[[VAL_8]], %[[VAL_23]], %[[BUF]] : tensor<?x?xi32, #{{.*}}>, memref<?xindex>, memref<i32>
// CHECK: } else {
// CHECK: }
// CHECK: %[[VAL_100:.*]] = arith.cmpi eq, %[[VAL_57]], %[[VAL_60]] : index


@@ -166,7 +166,7 @@ func.func @sampled_dd_unfused(%args: tensor<8x8xf64, #SM>,
// CHECK: scf.yield %[[VAL_37]] : f64
// CHECK: }
// CHECK: memref.store %[[VAL_30:.*]], %[[VAL_19]][] : memref<f64>
- // CHECK: sparse_tensor.lex_insert %[[VAL_10]], %[[VAL_18]], %[[VAL_19]] : tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>>, memref<?xindex>, memref<f64>
+ // CHECK: sparse_tensor.insert %[[VAL_10]], %[[VAL_18]], %[[VAL_19]] : tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>>, memref<?xindex>, memref<f64>
// CHECK: }
// CHECK: }
// CHECK: %[[VAL_39:.*]] = sparse_tensor.load %[[VAL_10]] hasInserts : tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>>


@@ -42,7 +42,7 @@
// CHECK: memref.store %[[VAL_21]], %[[VAL_11]]{{\[}}%[[VAL_2]]] : memref<?xindex>
// CHECK: %[[VAL_22:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_20]]] : memref<?xf64>
// CHECK: memref.store %[[VAL_22]], %[[VAL_12]][] : memref<f64>
- // CHECK: sparse_tensor.lex_insert %[[VAL_4]], %[[VAL_11]], %[[VAL_12]] : tensor<4x3xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>, memref<?xindex>, memref<f64>
+ // CHECK: sparse_tensor.insert %[[VAL_4]], %[[VAL_11]], %[[VAL_12]] : tensor<4x3xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>, memref<?xindex>, memref<f64>
// CHECK: }
// CHECK: }
// CHECK: %[[VAL_23:.*]] = sparse_tensor.load %[[VAL_4]] hasInserts : tensor<4x3xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>