mirror of
https://github.com/intel/llvm.git
synced 2026-01-19 01:15:50 +08:00
[mlir][sparse] Adding IsSparseTensorPred and updating ops to use it
Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D126994
This commit is contained in:
@@ -93,4 +93,28 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
|
||||
}];
|
||||
}
|
||||
|
||||
def IsSparseTensorPred
|
||||
: CPred<"!!::mlir::sparse_tensor::getSparseTensorEncoding($_self)">;
|
||||
|
||||
// The following four follow the same idiom as `TensorOf`, `AnyTensor`,
|
||||
// `RankedTensorOf`, `AnyRankedTensor`.
|
||||
|
||||
class SparseTensorOf<list<Type> allowedTypes>
|
||||
: ShapedContainerType<
|
||||
allowedTypes,
|
||||
And<[IsTensorTypePred, IsSparseTensorPred]>,
|
||||
"sparse tensor",
|
||||
"::mlir::TensorType">;
|
||||
|
||||
def AnySparseTensor : SparseTensorOf<[AnyType]>;
|
||||
|
||||
class RankedSparseTensorOf<list<Type> allowedTypes>
|
||||
: ShapedContainerType<
|
||||
allowedTypes,
|
||||
And<[IsTensorTypePred, HasRankPred, IsSparseTensorPred]>,
|
||||
"ranked sparse tensor",
|
||||
"::mlir::TensorType">;
|
||||
|
||||
def AnyRankedSparseTensor : RankedSparseTensorOf<[AnyType]>;
|
||||
|
||||
#endif // SPARSETENSOR_ATTRDEFS
|
||||
|
||||
@@ -27,7 +27,7 @@ class SparseTensor_Op<string mnemonic, list<Trait> traits = []>
|
||||
|
||||
def SparseTensor_NewOp : SparseTensor_Op<"new", [NoSideEffect]>,
|
||||
Arguments<(ins AnyType:$source)>,
|
||||
Results<(outs TensorOf<[AnyType]>:$result)> {
|
||||
Results<(outs AnySparseTensor:$result)> {
|
||||
string summary = "Materializes a new sparse tensor from given source";
|
||||
string description = [{
|
||||
Materializes a sparse tensor with contents taken from an opaque pointer
|
||||
@@ -46,7 +46,6 @@ def SparseTensor_NewOp : SparseTensor_Op<"new", [NoSideEffect]>,
|
||||
```
|
||||
}];
|
||||
let assemblyFormat = "$source attr-dict `:` type($source) `to` type($result)";
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
def SparseTensor_ConvertOp : SparseTensor_Op<"convert",
|
||||
@@ -92,7 +91,7 @@ def SparseTensor_ConvertOp : SparseTensor_Op<"convert",
|
||||
}
|
||||
|
||||
def SparseTensor_ToPointersOp : SparseTensor_Op<"pointers", [NoSideEffect]>,
|
||||
Arguments<(ins AnyTensor:$tensor, Index:$dim)>,
|
||||
Arguments<(ins AnySparseTensor:$tensor, Index:$dim)>,
|
||||
Results<(outs AnyStridedMemRefOfRank<1>:$result)> {
|
||||
let summary = "Extracts pointers array at given dimension from a tensor";
|
||||
let description = [{
|
||||
@@ -117,7 +116,7 @@ def SparseTensor_ToPointersOp : SparseTensor_Op<"pointers", [NoSideEffect]>,
|
||||
}
|
||||
|
||||
def SparseTensor_ToIndicesOp : SparseTensor_Op<"indices", [NoSideEffect]>,
|
||||
Arguments<(ins AnyTensor:$tensor, Index:$dim)>,
|
||||
Arguments<(ins AnySparseTensor:$tensor, Index:$dim)>,
|
||||
Results<(outs AnyStridedMemRefOfRank<1>:$result)> {
|
||||
let summary = "Extracts indices array at given dimension from a tensor";
|
||||
let description = [{
|
||||
@@ -142,7 +141,7 @@ def SparseTensor_ToIndicesOp : SparseTensor_Op<"indices", [NoSideEffect]>,
|
||||
}
|
||||
|
||||
def SparseTensor_ToValuesOp : SparseTensor_Op<"values", [NoSideEffect]>,
|
||||
Arguments<(ins AnyTensor:$tensor)>,
|
||||
Arguments<(ins AnySparseTensor:$tensor)>,
|
||||
Results<(outs AnyStridedMemRefOfRank<1>:$result)> {
|
||||
let summary = "Extracts numerical values array from a tensor";
|
||||
let description = [{
|
||||
@@ -173,7 +172,7 @@ def SparseTensor_ToValuesOp : SparseTensor_Op<"values", [NoSideEffect]>,
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def SparseTensor_LexInsertOp : SparseTensor_Op<"lex_insert", []>,
|
||||
Arguments<(ins AnyTensor:$tensor,
|
||||
Arguments<(ins AnySparseTensor:$tensor,
|
||||
StridedMemRefRankOf<[Index], [1]>:$indices,
|
||||
AnyType:$value)> {
|
||||
string summary = "Inserts a value into given sparse tensor in lexicographical index order";
|
||||
@@ -196,11 +195,10 @@ def SparseTensor_LexInsertOp : SparseTensor_Op<"lex_insert", []>,
|
||||
}];
|
||||
let assemblyFormat = "$tensor `,` $indices `,` $value attr-dict `:`"
|
||||
" type($tensor) `,` type($indices) `,` type($value)";
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
def SparseTensor_ExpandOp : SparseTensor_Op<"expand", []>,
|
||||
Arguments<(ins AnyTensor:$tensor)>,
|
||||
Arguments<(ins AnySparseTensor:$tensor)>,
|
||||
Results<(outs AnyStridedMemRefOfRank<1>:$values,
|
||||
StridedMemRefRankOf<[I1],[1]>:$filled,
|
||||
StridedMemRefRankOf<[Index],[1]>:$added,
|
||||
@@ -238,11 +236,10 @@ def SparseTensor_ExpandOp : SparseTensor_Op<"expand", []>,
|
||||
}];
|
||||
let assemblyFormat = "$tensor attr-dict `:` type($tensor) `to` type($values)"
|
||||
" `,` type($filled) `,` type($added) `,` type($count)";
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
def SparseTensor_CompressOp : SparseTensor_Op<"compress", []>,
|
||||
Arguments<(ins AnyTensor:$tensor,
|
||||
Arguments<(ins AnySparseTensor:$tensor,
|
||||
StridedMemRefRankOf<[Index],[1]>:$indices,
|
||||
AnyStridedMemRefOfRank<1>:$values,
|
||||
StridedMemRefRankOf<[I1],[1]>:$filled,
|
||||
@@ -273,11 +270,10 @@ def SparseTensor_CompressOp : SparseTensor_Op<"compress", []>,
|
||||
" $added `,` $count attr-dict `:` type($tensor) `,`"
|
||||
" type($indices) `,` type($values) `,` type($filled) `,`"
|
||||
" type($added) `,` type($count)";
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
def SparseTensor_LoadOp : SparseTensor_Op<"load", [SameOperandsAndResultType]>,
|
||||
Arguments<(ins AnyTensor:$tensor, UnitAttr:$hasInserts)>,
|
||||
Arguments<(ins AnySparseTensor:$tensor, UnitAttr:$hasInserts)>,
|
||||
Results<(outs AnyTensor:$result)> {
|
||||
let summary =
|
||||
"Rematerializes tensor from underlying sparse storage format";
|
||||
@@ -306,11 +302,10 @@ def SparseTensor_LoadOp : SparseTensor_Op<"load", [SameOperandsAndResultType]>,
|
||||
```
|
||||
}];
|
||||
let assemblyFormat = "$tensor (`hasInserts` $hasInserts^)? attr-dict `:` type($tensor)";
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
def SparseTensor_ReleaseOp : SparseTensor_Op<"release", []>,
|
||||
Arguments<(ins AnyTensor:$tensor)> {
|
||||
Arguments<(ins AnySparseTensor:$tensor)> {
|
||||
string summary = "Releases underlying sparse storage format of given tensor";
|
||||
string description = [{
|
||||
Releases the underlying sparse storage format for a tensor that
|
||||
@@ -332,11 +327,10 @@ def SparseTensor_ReleaseOp : SparseTensor_Op<"release", []>,
|
||||
```
|
||||
}];
|
||||
let assemblyFormat = "$tensor attr-dict `:` type($tensor)";
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
def SparseTensor_OutOp : SparseTensor_Op<"out", []>,
|
||||
Arguments<(ins AnyType:$tensor, AnyType:$dest)> {
|
||||
Arguments<(ins AnySparseTensor:$tensor, AnyType:$dest)> {
|
||||
string summary = "Outputs a sparse tensor to the given destination";
|
||||
string description = [{
|
||||
Outputs the contents of a sparse tensor to the destination defined by an
|
||||
@@ -353,7 +347,6 @@ def SparseTensor_OutOp : SparseTensor_Op<"out", []>,
|
||||
```
|
||||
}];
|
||||
let assemblyFormat = "$tensor `,` $dest attr-dict `:` type($tensor) `,` type($dest)";
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@@ -208,12 +208,6 @@ static LogicalResult isMatchingWidth(Value result, unsigned width) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
LogicalResult NewOp::verify() {
|
||||
if (!getSparseTensorEncoding(result().getType()))
|
||||
return emitError("expected a sparse tensor result");
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult ConvertOp::verify() {
|
||||
if (auto tp1 = source().getType().dyn_cast<RankedTensorType>()) {
|
||||
if (auto tp2 = dest().getType().dyn_cast<RankedTensorType>()) {
|
||||
@@ -240,30 +234,24 @@ OpFoldResult ConvertOp::fold(ArrayRef<Attribute> operands) {
|
||||
}
|
||||
|
||||
LogicalResult ToPointersOp::verify() {
|
||||
if (auto e = getSparseTensorEncoding(tensor().getType())) {
|
||||
if (failed(isInBounds(dim(), tensor())))
|
||||
return emitError("requested pointers dimension out of bounds");
|
||||
if (failed(isMatchingWidth(result(), e.getPointerBitWidth())))
|
||||
return emitError("unexpected type for pointers");
|
||||
return success();
|
||||
}
|
||||
return emitError("expected a sparse tensor to get pointers");
|
||||
auto e = getSparseTensorEncoding(tensor().getType());
|
||||
if (failed(isInBounds(dim(), tensor())))
|
||||
return emitError("requested pointers dimension out of bounds");
|
||||
if (failed(isMatchingWidth(result(), e.getPointerBitWidth())))
|
||||
return emitError("unexpected type for pointers");
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult ToIndicesOp::verify() {
|
||||
if (auto e = getSparseTensorEncoding(tensor().getType())) {
|
||||
if (failed(isInBounds(dim(), tensor())))
|
||||
return emitError("requested indices dimension out of bounds");
|
||||
if (failed(isMatchingWidth(result(), e.getIndexBitWidth())))
|
||||
return emitError("unexpected type for indices");
|
||||
return success();
|
||||
}
|
||||
return emitError("expected a sparse tensor to get indices");
|
||||
auto e = getSparseTensorEncoding(tensor().getType());
|
||||
if (failed(isInBounds(dim(), tensor())))
|
||||
return emitError("requested indices dimension out of bounds");
|
||||
if (failed(isMatchingWidth(result(), e.getIndexBitWidth())))
|
||||
return emitError("unexpected type for indices");
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult ToValuesOp::verify() {
|
||||
if (!getSparseTensorEncoding(tensor().getType()))
|
||||
return emitError("expected a sparse tensor to get values");
|
||||
RankedTensorType ttp = tensor().getType().cast<RankedTensorType>();
|
||||
MemRefType mtp = result().getType().cast<MemRefType>();
|
||||
if (ttp.getElementType() != mtp.getElementType())
|
||||
@@ -271,46 +259,6 @@ LogicalResult ToValuesOp::verify() {
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TensorDialect Management Operations.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult LexInsertOp::verify() {
|
||||
if (!getSparseTensorEncoding(tensor().getType()))
|
||||
return emitError("expected a sparse tensor for insertion");
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult ExpandOp::verify() {
|
||||
if (!getSparseTensorEncoding(tensor().getType()))
|
||||
return emitError("expected a sparse tensor for expansion");
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult CompressOp::verify() {
|
||||
if (!getSparseTensorEncoding(tensor().getType()))
|
||||
return emitError("expected a sparse tensor for compression");
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult LoadOp::verify() {
|
||||
if (!getSparseTensorEncoding(tensor().getType()))
|
||||
return emitError("expected a sparse tensor to materialize");
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult ReleaseOp::verify() {
|
||||
if (!getSparseTensorEncoding(tensor().getType()))
|
||||
return emitError("expected a sparse tensor to release");
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult OutOp::verify() {
|
||||
if (!getSparseTensorEncoding(tensor().getType()))
|
||||
return emitError("expected a sparse tensor for output");
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TensorDialect Linalg.Generic Operations.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
// RUN: mlir-opt %s -split-input-file -verify-diagnostics
|
||||
|
||||
func.func @invalid_new_dense(%arg0: !llvm.ptr<i8>) -> tensor<32xf32> {
|
||||
// expected-error@+1 {{expected a sparse tensor result}}
|
||||
// expected-error@+1 {{'sparse_tensor.new' op result #0 must be sparse tensor of any type values, but got 'tensor<32xf32>'}}
|
||||
%0 = sparse_tensor.new %arg0 : !llvm.ptr<i8> to tensor<32xf32>
|
||||
return %0 : tensor<32xf32>
|
||||
}
|
||||
@@ -9,7 +9,7 @@ func.func @invalid_new_dense(%arg0: !llvm.ptr<i8>) -> tensor<32xf32> {
|
||||
// -----
|
||||
|
||||
func.func @invalid_release_dense(%arg0: tensor<4xi32>) {
|
||||
// expected-error@+1 {{expected a sparse tensor to release}}
|
||||
// expected-error@+1 {{'sparse_tensor.release' op operand #0 must be sparse tensor of any type values, but got 'tensor<4xi32>'}}
|
||||
sparse_tensor.release %arg0 : tensor<4xi32>
|
||||
return
|
||||
}
|
||||
@@ -18,7 +18,7 @@ func.func @invalid_release_dense(%arg0: tensor<4xi32>) {
|
||||
|
||||
func.func @invalid_pointers_dense(%arg0: tensor<128xf64>) -> memref<?xindex> {
|
||||
%c = arith.constant 0 : index
|
||||
// expected-error@+1 {{expected a sparse tensor to get pointers}}
|
||||
// expected-error@+1 {{'sparse_tensor.pointers' op operand #0 must be sparse tensor of any type values, but got 'tensor<128xf64>'}}
|
||||
%0 = sparse_tensor.pointers %arg0, %c : tensor<128xf64> to memref<?xindex>
|
||||
return %0 : memref<?xindex>
|
||||
}
|
||||
@@ -27,7 +27,7 @@ func.func @invalid_pointers_dense(%arg0: tensor<128xf64>) -> memref<?xindex> {
|
||||
|
||||
func.func @invalid_pointers_unranked(%arg0: tensor<*xf64>) -> memref<?xindex> {
|
||||
%c = arith.constant 0 : index
|
||||
// expected-error@+1 {{expected a sparse tensor to get pointers}}
|
||||
// expected-error@+1 {{'sparse_tensor.pointers' op operand #0 must be sparse tensor of any type values, but got 'tensor<*xf64>'}}
|
||||
%0 = sparse_tensor.pointers %arg0, %c : tensor<*xf64> to memref<?xindex>
|
||||
return %0 : memref<?xindex>
|
||||
}
|
||||
@@ -58,7 +58,7 @@ func.func @pointers_oob(%arg0: tensor<128xf64, #SparseVector>) -> memref<?xindex
|
||||
|
||||
func.func @invalid_indices_dense(%arg0: tensor<10x10xi32>) -> memref<?xindex> {
|
||||
%c = arith.constant 1 : index
|
||||
// expected-error@+1 {{expected a sparse tensor to get indices}}
|
||||
// expected-error@+1 {{'sparse_tensor.indices' op operand #0 must be sparse tensor of any type values, but got 'tensor<10x10xi32>'}}
|
||||
%0 = sparse_tensor.indices %arg0, %c : tensor<10x10xi32> to memref<?xindex>
|
||||
return %0 : memref<?xindex>
|
||||
}
|
||||
@@ -67,7 +67,7 @@ func.func @invalid_indices_dense(%arg0: tensor<10x10xi32>) -> memref<?xindex> {
|
||||
|
||||
func.func @invalid_indices_unranked(%arg0: tensor<*xf64>) -> memref<?xindex> {
|
||||
%c = arith.constant 0 : index
|
||||
// expected-error@+1 {{expected a sparse tensor to get indices}}
|
||||
// expected-error@+1 {{'sparse_tensor.indices' op operand #0 must be sparse tensor of any type values, but got 'tensor<*xf64>'}}
|
||||
%0 = sparse_tensor.indices %arg0, %c : tensor<*xf64> to memref<?xindex>
|
||||
return %0 : memref<?xindex>
|
||||
}
|
||||
@@ -97,7 +97,7 @@ func.func @indices_oob(%arg0: tensor<128xf64, #SparseVector>) -> memref<?xindex>
|
||||
// -----
|
||||
|
||||
func.func @invalid_values_dense(%arg0: tensor<1024xf32>) -> memref<?xf32> {
|
||||
// expected-error@+1 {{expected a sparse tensor to get values}}
|
||||
// expected-error@+1 {{'sparse_tensor.values' op operand #0 must be sparse tensor of any type values, but got 'tensor<1024xf32>'}}
|
||||
%0 = sparse_tensor.values %arg0 : tensor<1024xf32> to memref<?xf32>
|
||||
return %0 : memref<?xf32>
|
||||
}
|
||||
@@ -115,7 +115,7 @@ func.func @mismatch_values_types(%arg0: tensor<?xf64, #SparseVector>) -> memref<
|
||||
// -----
|
||||
|
||||
func.func @sparse_unannotated_load(%arg0: tensor<16x32xf64>) -> tensor<16x32xf64> {
|
||||
// expected-error@+1 {{expected a sparse tensor to materialize}}
|
||||
// expected-error@+1 {{'sparse_tensor.load' op operand #0 must be sparse tensor of any type values, but got 'tensor<16x32xf64>'}}
|
||||
%0 = sparse_tensor.load %arg0 : tensor<16x32xf64>
|
||||
return %0 : tensor<16x32xf64>
|
||||
}
|
||||
@@ -123,7 +123,7 @@ func.func @sparse_unannotated_load(%arg0: tensor<16x32xf64>) -> tensor<16x32xf64
|
||||
// -----
|
||||
|
||||
func.func @sparse_unannotated_insert(%arg0: tensor<128xf64>, %arg1: memref<?xindex>, %arg2: f64) {
|
||||
// expected-error@+1 {{expected a sparse tensor for insertion}}
|
||||
// expected-error@+1 {{'sparse_tensor.lex_insert' op operand #0 must be sparse tensor of any type values, but got 'tensor<128xf64>'}}
|
||||
sparse_tensor.lex_insert %arg0, %arg1, %arg2 : tensor<128xf64>, memref<?xindex>, f64
|
||||
return
|
||||
}
|
||||
@@ -131,7 +131,7 @@ func.func @sparse_unannotated_insert(%arg0: tensor<128xf64>, %arg1: memref<?xind
|
||||
// -----
|
||||
|
||||
func.func @sparse_unannotated_expansion(%arg0: tensor<128xf64>) {
|
||||
// expected-error@+1 {{expected a sparse tensor for expansion}}
|
||||
// expected-error@+1 {{'sparse_tensor.expand' op operand #0 must be sparse tensor of any type values, but got 'tensor<128xf64>'}}
|
||||
%values, %filled, %added, %count = sparse_tensor.expand %arg0
|
||||
: tensor<128xf64> to memref<?xf64>, memref<?xi1>, memref<?xindex>, index
|
||||
return
|
||||
@@ -142,7 +142,7 @@ func.func @sparse_unannotated_expansion(%arg0: tensor<128xf64>) {
|
||||
func.func @sparse_unannotated_compression(%arg0: tensor<128xf64>, %arg1: memref<?xindex>,
|
||||
%arg2: memref<?xf64>, %arg3: memref<?xi1>,
|
||||
%arg4: memref<?xindex>, %arg5: index) {
|
||||
// expected-error@+1 {{expected a sparse tensor for compression}}
|
||||
// expected-error@+1 {{'sparse_tensor.compress' op operand #0 must be sparse tensor of any type values, but got 'tensor<128xf64>'}}
|
||||
sparse_tensor.compress %arg0, %arg1, %arg2, %arg3, %arg4, %arg5
|
||||
: tensor<128xf64>, memref<?xindex>, memref<?xf64>, memref<?xi1>, memref<?xindex>, index
|
||||
}
|
||||
@@ -178,7 +178,7 @@ func.func @sparse_convert_dim_mismatch(%arg0: tensor<10x?xf32>) -> tensor<10x10x
|
||||
// -----
|
||||
|
||||
func.func @invalid_out_dense(%arg0: tensor<10xf64>, %arg1: !llvm.ptr<i8>) {
|
||||
// expected-error@+1 {{expected a sparse tensor for output}}
|
||||
// expected-error@+1 {{'sparse_tensor.out' op operand #0 must be sparse tensor of any type values, but got 'tensor<10xf64>'}}
|
||||
sparse_tensor.out %arg0, %arg1 : tensor<10xf64>, !llvm.ptr<i8>
|
||||
return
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user