[mlir][sparse] Adding IsSparseTensorPred and updating ops to use it

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D126994
Author: wren romano
Date:   2022-06-03 16:41:02 -07:00
Parent: 604016dbe4
Commit: 3cf03f1c56
4 changed files with 58 additions and 93 deletions

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td

@@ -93,4 +93,28 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
}];
}
+def IsSparseTensorPred
+  : CPred<"!!::mlir::sparse_tensor::getSparseTensorEncoding($_self)">;
+
+// The following four follow the same idiom as `TensorOf`, `AnyTensor`,
+// `RankedTensorOf`, `AnyRankedTensor`.
+
+class SparseTensorOf<list<Type> allowedTypes>
+  : ShapedContainerType<
+      allowedTypes,
+      And<[IsTensorTypePred, IsSparseTensorPred]>,
+      "sparse tensor",
+      "::mlir::TensorType">;
+
+def AnySparseTensor : SparseTensorOf<[AnyType]>;
+
+class RankedSparseTensorOf<list<Type> allowedTypes>
+  : ShapedContainerType<
+      allowedTypes,
+      And<[IsTensorTypePred, HasRankPred, IsSparseTensorPred]>,
+      "ranked sparse tensor",
+      "::mlir::TensorType">;
+
+def AnyRankedSparseTensor : RankedSparseTensorOf<[AnyType]>;
+
#endif // SPARSETENSOR_ATTRDEFS
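
As a rough sketch of what these new constraints accept (not part of this diff): IsSparseTensorPred holds exactly when a tensor type carries a #sparse_tensor.encoding attribute, i.e. when getSparseTensorEncoding returns a non-null encoding. A hypothetical 1-D compressed type, in the encoding syntax used at the time of this change:

#SparseVector = #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>

// tensor<128xf64, #SparseVector> carries an encoding, so it satisfies
// AnySparseTensor (and AnyRankedSparseTensor, since it is ranked);
// a plain tensor<128xf64> has no encoding and is rejected.
func.func @accepts_sparse(%arg0: tensor<128xf64, #SparseVector>) {
  return
}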

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td

@@ -27,7 +27,7 @@ class SparseTensor_Op<string mnemonic, list<Trait> traits = []>
def SparseTensor_NewOp : SparseTensor_Op<"new", [NoSideEffect]>,
Arguments<(ins AnyType:$source)>,
-    Results<(outs TensorOf<[AnyType]>:$result)> {
+    Results<(outs AnySparseTensor:$result)> {
string summary = "Materializes a new sparse tensor from given source";
string description = [{
Materializes a sparse tensor with contents taken from an opaque pointer
@@ -46,7 +46,6 @@ def SparseTensor_NewOp : SparseTensor_Op<"new", [NoSideEffect]>,
```
}];
let assemblyFormat = "$source attr-dict `:` type($source) `to` type($result)";
-  let hasVerifier = 1;
}
def SparseTensor_ConvertOp : SparseTensor_Op<"convert",
@@ -92,7 +91,7 @@ def SparseTensor_ConvertOp : SparseTensor_Op<"convert",
}
def SparseTensor_ToPointersOp : SparseTensor_Op<"pointers", [NoSideEffect]>,
-    Arguments<(ins AnyTensor:$tensor, Index:$dim)>,
+    Arguments<(ins AnySparseTensor:$tensor, Index:$dim)>,
Results<(outs AnyStridedMemRefOfRank<1>:$result)> {
let summary = "Extracts pointers array at given dimension from a tensor";
let description = [{
@@ -117,7 +116,7 @@ def SparseTensor_ToPointersOp : SparseTensor_Op<"pointers", [NoSideEffect]>,
}
def SparseTensor_ToIndicesOp : SparseTensor_Op<"indices", [NoSideEffect]>,
-    Arguments<(ins AnyTensor:$tensor, Index:$dim)>,
+    Arguments<(ins AnySparseTensor:$tensor, Index:$dim)>,
Results<(outs AnyStridedMemRefOfRank<1>:$result)> {
let summary = "Extracts indices array at given dimension from a tensor";
let description = [{
@@ -142,7 +141,7 @@ def SparseTensor_ToIndicesOp : SparseTensor_Op<"indices", [NoSideEffect]>,
}
def SparseTensor_ToValuesOp : SparseTensor_Op<"values", [NoSideEffect]>,
-    Arguments<(ins AnyTensor:$tensor)>,
+    Arguments<(ins AnySparseTensor:$tensor)>,
Results<(outs AnyStridedMemRefOfRank<1>:$result)> {
let summary = "Extracts numerical values array from a tensor";
let description = [{
@@ -173,7 +172,7 @@ def SparseTensor_ToValuesOp : SparseTensor_Op<"values", [NoSideEffect]>,
//===----------------------------------------------------------------------===//
def SparseTensor_LexInsertOp : SparseTensor_Op<"lex_insert", []>,
-    Arguments<(ins AnyTensor:$tensor,
+    Arguments<(ins AnySparseTensor:$tensor,
StridedMemRefRankOf<[Index], [1]>:$indices,
AnyType:$value)> {
string summary = "Inserts a value into given sparse tensor in lexicographical index order";
@@ -196,11 +195,10 @@ def SparseTensor_LexInsertOp : SparseTensor_Op<"lex_insert", []>,
}];
let assemblyFormat = "$tensor `,` $indices `,` $value attr-dict `:`"
" type($tensor) `,` type($indices) `,` type($value)";
-  let hasVerifier = 1;
}
def SparseTensor_ExpandOp : SparseTensor_Op<"expand", []>,
-    Arguments<(ins AnyTensor:$tensor)>,
+    Arguments<(ins AnySparseTensor:$tensor)>,
Results<(outs AnyStridedMemRefOfRank<1>:$values,
StridedMemRefRankOf<[I1],[1]>:$filled,
StridedMemRefRankOf<[Index],[1]>:$added,
@@ -238,11 +236,10 @@ def SparseTensor_ExpandOp : SparseTensor_Op<"expand", []>,
}];
let assemblyFormat = "$tensor attr-dict `:` type($tensor) `to` type($values)"
" `,` type($filled) `,` type($added) `,` type($count)";
-  let hasVerifier = 1;
}
def SparseTensor_CompressOp : SparseTensor_Op<"compress", []>,
-    Arguments<(ins AnyTensor:$tensor,
+    Arguments<(ins AnySparseTensor:$tensor,
StridedMemRefRankOf<[Index],[1]>:$indices,
AnyStridedMemRefOfRank<1>:$values,
StridedMemRefRankOf<[I1],[1]>:$filled,
@@ -273,11 +270,10 @@ def SparseTensor_CompressOp : SparseTensor_Op<"compress", []>,
" $added `,` $count attr-dict `:` type($tensor) `,`"
" type($indices) `,` type($values) `,` type($filled) `,`"
" type($added) `,` type($count)";
-  let hasVerifier = 1;
}
def SparseTensor_LoadOp : SparseTensor_Op<"load", [SameOperandsAndResultType]>,
-    Arguments<(ins AnyTensor:$tensor, UnitAttr:$hasInserts)>,
+    Arguments<(ins AnySparseTensor:$tensor, UnitAttr:$hasInserts)>,
Results<(outs AnyTensor:$result)> {
let summary =
"Rematerializes tensor from underlying sparse storage format";
@@ -306,11 +302,10 @@ def SparseTensor_LoadOp : SparseTensor_Op<"load", [SameOperandsAndResultType]>,
```
}];
let assemblyFormat = "$tensor (`hasInserts` $hasInserts^)? attr-dict `:` type($tensor)";
-  let hasVerifier = 1;
}
def SparseTensor_ReleaseOp : SparseTensor_Op<"release", []>,
-    Arguments<(ins AnyTensor:$tensor)> {
+    Arguments<(ins AnySparseTensor:$tensor)> {
string summary = "Releases underlying sparse storage format of given tensor";
string description = [{
Releases the underlying sparse storage format for a tensor that
@@ -332,11 +327,10 @@ def SparseTensor_ReleaseOp : SparseTensor_Op<"release", []>,
```
}];
let assemblyFormat = "$tensor attr-dict `:` type($tensor)";
-  let hasVerifier = 1;
}
def SparseTensor_OutOp : SparseTensor_Op<"out", []>,
-    Arguments<(ins AnyType:$tensor, AnyType:$dest)> {
+    Arguments<(ins AnySparseTensor:$tensor, AnyType:$dest)> {
string summary = "Outputs a sparse tensor to the given destination";
string description = [{
Outputs the contents of a sparse tensor to the destination defined by an
@@ -353,7 +347,6 @@ def SparseTensor_OutOp : SparseTensor_Op<"out", []>,
```
}];
let assemblyFormat = "$tensor `,` $dest attr-dict `:` type($tensor) `,` type($dest)";
-  let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
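
With the operand and result constraints above, a well-formed use of these ops is now checked by the ODS-generated verifier rather than by hand-written C++. A minimal sketch (not part of this diff), reusing the hypothetical #SparseVector encoding from the earlier note:

#SparseVector = #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>

func.func @sparse_pointers(%arg0: tensor<128xf64, #SparseVector>) -> memref<?xindex> {
  %c0 = arith.constant 0 : index
  // Operand #0 must satisfy AnySparseTensor; the encoding on the operand's
  // tensor type is what makes this verify.
  %0 = sparse_tensor.pointers %arg0, %c0 : tensor<128xf64, #SparseVector> to memref<?xindex>
  return %0 : memref<?xindex>
}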

mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp

@@ -208,12 +208,6 @@ static LogicalResult isMatchingWidth(Value result, unsigned width) {
return failure();
}
-LogicalResult NewOp::verify() {
-  if (!getSparseTensorEncoding(result().getType()))
-    return emitError("expected a sparse tensor result");
-  return success();
-}
-
LogicalResult ConvertOp::verify() {
if (auto tp1 = source().getType().dyn_cast<RankedTensorType>()) {
if (auto tp2 = dest().getType().dyn_cast<RankedTensorType>()) {
@@ -240,30 +234,24 @@ OpFoldResult ConvertOp::fold(ArrayRef<Attribute> operands) {
}
LogicalResult ToPointersOp::verify() {
-  if (auto e = getSparseTensorEncoding(tensor().getType())) {
-    if (failed(isInBounds(dim(), tensor())))
-      return emitError("requested pointers dimension out of bounds");
-    if (failed(isMatchingWidth(result(), e.getPointerBitWidth())))
-      return emitError("unexpected type for pointers");
-    return success();
-  }
-  return emitError("expected a sparse tensor to get pointers");
+  auto e = getSparseTensorEncoding(tensor().getType());
+  if (failed(isInBounds(dim(), tensor())))
+    return emitError("requested pointers dimension out of bounds");
+  if (failed(isMatchingWidth(result(), e.getPointerBitWidth())))
+    return emitError("unexpected type for pointers");
+  return success();
}
LogicalResult ToIndicesOp::verify() {
-  if (auto e = getSparseTensorEncoding(tensor().getType())) {
-    if (failed(isInBounds(dim(), tensor())))
-      return emitError("requested indices dimension out of bounds");
-    if (failed(isMatchingWidth(result(), e.getIndexBitWidth())))
-      return emitError("unexpected type for indices");
-    return success();
-  }
-  return emitError("expected a sparse tensor to get indices");
+  auto e = getSparseTensorEncoding(tensor().getType());
+  if (failed(isInBounds(dim(), tensor())))
+    return emitError("requested indices dimension out of bounds");
+  if (failed(isMatchingWidth(result(), e.getIndexBitWidth())))
+    return emitError("unexpected type for indices");
+  return success();
}
LogicalResult ToValuesOp::verify() {
-  if (!getSparseTensorEncoding(tensor().getType()))
-    return emitError("expected a sparse tensor to get values");
RankedTensorType ttp = tensor().getType().cast<RankedTensorType>();
MemRefType mtp = result().getType().cast<MemRefType>();
if (ttp.getElementType() != mtp.getElementType())
@@ -271,46 +259,6 @@ LogicalResult ToValuesOp::verify() {
return success();
}
-//===----------------------------------------------------------------------===//
-// TensorDialect Management Operations.
-//===----------------------------------------------------------------------===//
-
-LogicalResult LexInsertOp::verify() {
-  if (!getSparseTensorEncoding(tensor().getType()))
-    return emitError("expected a sparse tensor for insertion");
-  return success();
-}
-
-LogicalResult ExpandOp::verify() {
-  if (!getSparseTensorEncoding(tensor().getType()))
-    return emitError("expected a sparse tensor for expansion");
-  return success();
-}
-
-LogicalResult CompressOp::verify() {
-  if (!getSparseTensorEncoding(tensor().getType()))
-    return emitError("expected a sparse tensor for compression");
-  return success();
-}
-
-LogicalResult LoadOp::verify() {
-  if (!getSparseTensorEncoding(tensor().getType()))
-    return emitError("expected a sparse tensor to materialize");
-  return success();
-}
-
-LogicalResult ReleaseOp::verify() {
-  if (!getSparseTensorEncoding(tensor().getType()))
-    return emitError("expected a sparse tensor to release");
-  return success();
-}
-
-LogicalResult OutOp::verify() {
-  if (!getSparseTensorEncoding(tensor().getType()))
-    return emitError("expected a sparse tensor for output");
-  return success();
-}
-
//===----------------------------------------------------------------------===//
// TensorDialect Linalg.Generic Operations.
//===----------------------------------------------------------------------===//

mlir/test/Dialect/SparseTensor/invalid.mlir

@@ -1,7 +1,7 @@
// RUN: mlir-opt %s -split-input-file -verify-diagnostics
func.func @invalid_new_dense(%arg0: !llvm.ptr<i8>) -> tensor<32xf32> {
-  // expected-error@+1 {{expected a sparse tensor result}}
+  // expected-error@+1 {{'sparse_tensor.new' op result #0 must be sparse tensor of any type values, but got 'tensor<32xf32>'}}
%0 = sparse_tensor.new %arg0 : !llvm.ptr<i8> to tensor<32xf32>
return %0 : tensor<32xf32>
}
@@ -9,7 +9,7 @@ func.func @invalid_new_dense(%arg0: !llvm.ptr<i8>) -> tensor<32xf32> {
// -----
func.func @invalid_release_dense(%arg0: tensor<4xi32>) {
-  // expected-error@+1 {{expected a sparse tensor to release}}
+  // expected-error@+1 {{'sparse_tensor.release' op operand #0 must be sparse tensor of any type values, but got 'tensor<4xi32>'}}
sparse_tensor.release %arg0 : tensor<4xi32>
return
}
@@ -18,7 +18,7 @@ func.func @invalid_release_dense(%arg0: tensor<4xi32>) {
func.func @invalid_pointers_dense(%arg0: tensor<128xf64>) -> memref<?xindex> {
%c = arith.constant 0 : index
-  // expected-error@+1 {{expected a sparse tensor to get pointers}}
+  // expected-error@+1 {{'sparse_tensor.pointers' op operand #0 must be sparse tensor of any type values, but got 'tensor<128xf64>'}}
%0 = sparse_tensor.pointers %arg0, %c : tensor<128xf64> to memref<?xindex>
return %0 : memref<?xindex>
}
@@ -27,7 +27,7 @@ func.func @invalid_pointers_dense(%arg0: tensor<128xf64>) -> memref<?xindex> {
func.func @invalid_pointers_unranked(%arg0: tensor<*xf64>) -> memref<?xindex> {
%c = arith.constant 0 : index
-  // expected-error@+1 {{expected a sparse tensor to get pointers}}
+  // expected-error@+1 {{'sparse_tensor.pointers' op operand #0 must be sparse tensor of any type values, but got 'tensor<*xf64>'}}
%0 = sparse_tensor.pointers %arg0, %c : tensor<*xf64> to memref<?xindex>
return %0 : memref<?xindex>
}
@@ -58,7 +58,7 @@ func.func @pointers_oob(%arg0: tensor<128xf64, #SparseVector>) -> memref<?xindex
func.func @invalid_indices_dense(%arg0: tensor<10x10xi32>) -> memref<?xindex> {
%c = arith.constant 1 : index
-  // expected-error@+1 {{expected a sparse tensor to get indices}}
+  // expected-error@+1 {{'sparse_tensor.indices' op operand #0 must be sparse tensor of any type values, but got 'tensor<10x10xi32>'}}
%0 = sparse_tensor.indices %arg0, %c : tensor<10x10xi32> to memref<?xindex>
return %0 : memref<?xindex>
}
@@ -67,7 +67,7 @@ func.func @invalid_indices_dense(%arg0: tensor<10x10xi32>) -> memref<?xindex> {
func.func @invalid_indices_unranked(%arg0: tensor<*xf64>) -> memref<?xindex> {
%c = arith.constant 0 : index
-  // expected-error@+1 {{expected a sparse tensor to get indices}}
+  // expected-error@+1 {{'sparse_tensor.indices' op operand #0 must be sparse tensor of any type values, but got 'tensor<*xf64>'}}
%0 = sparse_tensor.indices %arg0, %c : tensor<*xf64> to memref<?xindex>
return %0 : memref<?xindex>
}
@@ -97,7 +97,7 @@ func.func @indices_oob(%arg0: tensor<128xf64, #SparseVector>) -> memref<?xindex>
// -----
func.func @invalid_values_dense(%arg0: tensor<1024xf32>) -> memref<?xf32> {
-  // expected-error@+1 {{expected a sparse tensor to get values}}
+  // expected-error@+1 {{'sparse_tensor.values' op operand #0 must be sparse tensor of any type values, but got 'tensor<1024xf32>'}}
%0 = sparse_tensor.values %arg0 : tensor<1024xf32> to memref<?xf32>
return %0 : memref<?xf32>
}
@@ -115,7 +115,7 @@ func.func @mismatch_values_types(%arg0: tensor<?xf64, #SparseVector>) -> memref<
// -----
func.func @sparse_unannotated_load(%arg0: tensor<16x32xf64>) -> tensor<16x32xf64> {
-  // expected-error@+1 {{expected a sparse tensor to materialize}}
+  // expected-error@+1 {{'sparse_tensor.load' op operand #0 must be sparse tensor of any type values, but got 'tensor<16x32xf64>'}}
%0 = sparse_tensor.load %arg0 : tensor<16x32xf64>
return %0 : tensor<16x32xf64>
}
@@ -123,7 +123,7 @@ func.func @sparse_unannotated_load(%arg0: tensor<16x32xf64>) -> tensor<16x32xf64
// -----
func.func @sparse_unannotated_insert(%arg0: tensor<128xf64>, %arg1: memref<?xindex>, %arg2: f64) {
-  // expected-error@+1 {{expected a sparse tensor for insertion}}
+  // expected-error@+1 {{'sparse_tensor.lex_insert' op operand #0 must be sparse tensor of any type values, but got 'tensor<128xf64>'}}
sparse_tensor.lex_insert %arg0, %arg1, %arg2 : tensor<128xf64>, memref<?xindex>, f64
return
}
@@ -131,7 +131,7 @@ func.func @sparse_unannotated_insert(%arg0: tensor<128xf64>, %arg1: memref<?xind
// -----
func.func @sparse_unannotated_expansion(%arg0: tensor<128xf64>) {
-  // expected-error@+1 {{expected a sparse tensor for expansion}}
+  // expected-error@+1 {{'sparse_tensor.expand' op operand #0 must be sparse tensor of any type values, but got 'tensor<128xf64>'}}
%values, %filled, %added, %count = sparse_tensor.expand %arg0
: tensor<128xf64> to memref<?xf64>, memref<?xi1>, memref<?xindex>, index
return
@@ -142,7 +142,7 @@ func.func @sparse_unannotated_expansion(%arg0: tensor<128xf64>) {
func.func @sparse_unannotated_compression(%arg0: tensor<128xf64>, %arg1: memref<?xindex>,
%arg2: memref<?xf64>, %arg3: memref<?xi1>,
%arg4: memref<?xindex>, %arg5: index) {
-  // expected-error@+1 {{expected a sparse tensor for compression}}
+  // expected-error@+1 {{'sparse_tensor.compress' op operand #0 must be sparse tensor of any type values, but got 'tensor<128xf64>'}}
sparse_tensor.compress %arg0, %arg1, %arg2, %arg3, %arg4, %arg5
: tensor<128xf64>, memref<?xindex>, memref<?xf64>, memref<?xi1>, memref<?xindex>, index
}
@@ -178,7 +178,7 @@ func.func @sparse_convert_dim_mismatch(%arg0: tensor<10x?xf32>) -> tensor<10x10x
// -----
func.func @invalid_out_dense(%arg0: tensor<10xf64>, %arg1: !llvm.ptr<i8>) {
-  // expected-error@+1 {{expected a sparse tensor for output}}
+  // expected-error@+1 {{'sparse_tensor.out' op operand #0 must be sparse tensor of any type values, but got 'tensor<10xf64>'}}
sparse_tensor.out %arg0, %arg1 : tensor<10xf64>, !llvm.ptr<i8>
return
}