[mlir][sparse] Add sparse_tensor.select operation

The new select operation allows filtering of sparse tensors
by conditionally keeping or removing each element. This
can be used to remove negative values or select the upper
triangle of a matrix.

The select op has a single region which operates on a single
value and must return a boolean True to keep or False to drop.

Reviewed by: aartbik

Differential Revision: https://reviews.llvm.org/D133569
This commit is contained in:
Jim Kitchen
2022-09-13 15:22:53 -05:00
parent 67f08bf1bf
commit 07150fece5
4 changed files with 144 additions and 10 deletions

View File

@@ -604,11 +604,72 @@ def SparseTensor_ReduceOp : SparseTensor_Op<"reduce", [NoSideEffect, SameOperand
let hasVerifier = 1;
}
def SparseTensor_SelectOp : SparseTensor_Op<"select", [NoSideEffect, SameOperandsAndResultType]>,
Arguments<(ins AnyType:$x)>,
Results<(outs AnyType:$output)> {
let summary = "Select operation utilized within linalg.generic";
let description = [{
Defines an evaluation within a `linalg.generic` operation that takes a single
operand and decides whether or not to keep that operand in the output.
A single region must contain exactly one block taking one argument. The block
must end with a sparse_tensor.yield and the output type must be boolean.
Value threshold is an obvious usage of the select operation. However, by using
`linalg.index`, other useful selection can be achieved, such as selecting the
upper triangle of a matrix.
Example of selecting A >= 4.0:
```mlir
%C = bufferization.alloc_tensor...
%0 = linalg.generic #trait
ins(%A: tensor<?xf64, #SparseVector>)
outs(%C: tensor<?xf64, #SparseVector>) {
^bb0(%a: f64, %c: f64) :
%result = sparse_tensor.select %a : f64 {
^bb0(%arg0: f64):
%cf4 = arith.constant 4.0 : f64
%keep = arith.cmpf "uge", %arg0, %cf4 : f64
sparse_tensor.yield %keep : i1
}
linalg.yield %result : f64
} -> tensor<?xf64, #SparseVector>
```
Example of selecting lower triangle of a matrix:
```mlir
%C = bufferization.alloc_tensor...
%0 = linalg.generic #trait
ins(%A: tensor<?x?xf64, #CSR>)
outs(%C: tensor<?x?xf64, #CSR>) {
^bb0(%a: f64, %c: f64) :
%row = linalg.index 0 : index
%col = linalg.index 1 : index
%result = sparse_tensor.select %a : f64 {
^bb0(%arg0: f64):
%keep = arith.cmpf "olt", %col, %row : f64
sparse_tensor.yield %keep : i1
}
linalg.yield %result : f64
} -> tensor<?x?xf64, #CSR>
```
}];
let regions = (region SizedRegion<1>:$region);
let assemblyFormat = [{
$x attr-dict `:` type($x) $region
}];
let hasVerifier = 1;
}
def SparseTensor_YieldOp : SparseTensor_Op<"yield", [NoSideEffect, Terminator]>,
Arguments<(ins AnyType:$result)> {
let summary = "Yield from sparse_tensor set-like operations";
let description = [{
Yields a value from within a `binary` or `unary` block.
Yields a value from within a `binary`, `unary`, `reduce`,
or `select` block.
Example:

View File

@@ -458,12 +458,27 @@ LogicalResult ReduceOp::verify() {
// Check correct number of block arguments and return type.
Region &formula = getRegion();
if (!formula.empty()) {
regionResult = verifyNumBlockArgs(
this, formula, "reduce", TypeRange{inputType, inputType}, inputType);
if (failed(regionResult))
return regionResult;
}
regionResult = verifyNumBlockArgs(this, formula, "reduce",
TypeRange{inputType, inputType}, inputType);
if (failed(regionResult))
return regionResult;
return success();
}
LogicalResult SelectOp::verify() {
Builder b(getContext());
Type inputType = getX().getType();
Type boolType = b.getI1Type();
LogicalResult regionResult = success();
// Check correct number of block arguments and return type.
Region &formula = getRegion();
regionResult = verifyNumBlockArgs(this, formula, "select",
TypeRange{inputType}, boolType);
if (failed(regionResult))
return regionResult;
return success();
}
@@ -472,11 +487,11 @@ LogicalResult YieldOp::verify() {
// Check for compatible parent.
auto *parentOp = (*this)->getParentOp();
if (isa<BinaryOp>(parentOp) || isa<UnaryOp>(parentOp) ||
isa<ReduceOp>(parentOp))
isa<ReduceOp>(parentOp) || isa<SelectOp>(parentOp))
return success();
return emitOpError(
"expected parent op to be sparse_tensor unary, binary, or reduce");
return emitOpError("expected parent op to be sparse_tensor unary, binary, "
"reduce, or select");
}
//===----------------------------------------------------------------------===//

View File

@@ -355,6 +355,40 @@ func.func @invalid_reduce_wrong_yield(%arg0: f64, %arg1: f64) -> f64 {
// -----
func.func @invalid_select_num_args_mismatch(%arg0: f64) -> f64 {
// expected-error@+1 {{select region must have exactly 1 arguments}}
%r = sparse_tensor.select %arg0 : f64 {
^bb0(%x: f64, %y: f64):
%ret = arith.constant 1 : i1
sparse_tensor.yield %ret : i1
}
return %r : f64
}
// -----
func.func @invalid_select_return_type_mismatch(%arg0: f64) -> f64 {
// expected-error@+1 {{select region yield type mismatch}}
%r = sparse_tensor.select %arg0 : f64 {
^bb0(%x: f64):
sparse_tensor.yield %x : f64
}
return %r : f64
}
// -----
func.func @invalid_select_wrong_yield(%arg0: f64) -> f64 {
// expected-error@+1 {{select region must end with sparse_tensor.yield}}
%r = sparse_tensor.select %arg0 : f64 {
^bb0(%x: f64):
tensor.yield %x : f64
}
return %r : f64
}
// -----
#DC = #sparse_tensor.encoding<{dimLevelType = ["dense", "compressed"]}>
func.func @invalid_concat_less_inputs(%arg: tensor<9x4xf64, #DC>) -> tensor<9x4xf64, #DC> {
// expected-error@+1 {{Need at least two tensors to concatenate.}}

View File

@@ -291,6 +291,30 @@ func.func @sparse_reduce_2d_to_1d(%arg0: f64, %arg1: f64) -> f64 {
#SparseMatrix = #sparse_tensor.encoding<{dimLevelType = ["compressed", "compressed"]}>
// CHECK-LABEL: func @sparse_select(
// CHECK-SAME: %[[A:.*]]: f64) -> f64 {
// CHECK: %[[Z:.*]] = arith.constant 0.000000e+00 : f64
// CHECK: %[[C1:.*]] = sparse_tensor.select %[[A]] : f64 {
// CHECK: ^bb0(%[[A1:.*]]: f64):
// CHECK: %[[B1:.*]] = arith.cmpf ogt, %[[A1]], %[[Z]] : f64
// CHECK: sparse_tensor.yield %[[B1]] : i1
// CHECK: }
// CHECK: return %[[C1]] : f64
// CHECK: }
func.func @sparse_select(%arg0: f64) -> f64 {
%cf0 = arith.constant 0.0 : f64
%r = sparse_tensor.select %arg0 : f64 {
^bb0(%x: f64):
%cmp = arith.cmpf "ogt", %x, %cf0 : f64
sparse_tensor.yield %cmp : i1
}
return %r : f64
}
// -----
#SparseMatrix = #sparse_tensor.encoding<{dimLevelType = ["compressed", "compressed"]}>
// CHECK-LABEL: func @concat_sparse_sparse(
// CHECK-SAME: %[[A0:.*]]: tensor<2x4xf64
// CHECK-SAME: %[[A1:.*]]: tensor<3x4xf64