From 8c02ca1da5bdcb7f7e850afb24d95bb6d82d8971 Mon Sep 17 00:00:00 2001 From: bixia1 Date: Tue, 4 Oct 2022 10:37:38 -0700 Subject: [PATCH] [mlir][sparse] Add an attribute to the sort operator for stable sorting. Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D135181 --- .../SparseTensor/IR/SparseTensorOps.td | 20 +++++++++++++++---- .../SparseTensor/IR/SparseTensorDialect.cpp | 5 +++++ mlir/test/Dialect/SparseTensor/roundtrip.mlir | 14 +++++++++++++ 3 files changed, 35 insertions(+), 4 deletions(-) diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td index 8cd1a01a2e33..4d1f23719ee0 100644 --- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td @@ -415,7 +415,8 @@ def SparseTensor_SortOp : SparseTensor_Op<"sort", [AttrSizedOperandSegments]>, // and then use NonemptyVariadic<...>:$xs here. Arguments<(ins Index:$n, Variadic>:$xs, - Variadic>:$ys)> { + Variadic>:$ys, + UnitAttr:$stable)> { string summary = "Sorts the arrays in xs and ys lexicographically on the " "integral values found in the xs list"; string description = [{ @@ -437,6 +438,9 @@ def SparseTensor_SortOp : SparseTensor_Op<"sort", [AttrSizedOperandSegments]>, is undefined if this condition is not met. The operator requires at least one buffer in `xs` while `ys` can be empty. + The `stable` attribute indicates whether a stable sorting algorithm should + be used to implement the operator. + Note that this operation is "impure" in the sense that its behavior is solely defined by side-effects and not SSA values. The semantics may be refined over time as our sparse abstractions evolve. @@ -447,10 +451,18 @@ def SparseTensor_SortOp : SparseTensor_Op<"sort", [AttrSizedOperandSegments]>, sparse_tensor.sort %n, %x1, %x2 jointly y1, %y2 : memref, memref jointly memref, memref ``` - }]; - let assemblyFormat = "$n `,` $xs (`jointly` $ys^)? attr-dict" - "`:` type($xs) (`jointly` type($ys)^)?"; + ```mlir + sparse_tensor.sort stable %n, %x1, %x2 jointly y1, %y2 + : memref, memref jointly memref, memref + ``` + }]; + let assemblyFormat = "(`stable` $stable^)? $n" + "`,`$xs (`jointly` $ys^)? attr-dict" + "`:` type($xs) (`jointly` type($ys)^)?"; + let builders = [ + OpBuilder<(ins "Value":$n, "ValueRange":$xs, "ValueRange":$ys)> + ]; let hasVerifier = 1; } diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp index d12ecb9d023f..9f12a5481e40 100644 --- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp @@ -675,6 +675,11 @@ LogicalResult SelectOp::verify() { return success(); } +void SortOp::build(OpBuilder &odsBuilder, OperationState &odsState, Value n, + ValueRange xs, ValueRange ys) { + build(odsBuilder, odsState, n, xs, ys, /*stable=*/false); +} + LogicalResult SortOp::verify() { if (getXs().empty()) return emitError("need at least one xs buffer."); diff --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir index ee69af3d256a..fd850aacacae 100644 --- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir +++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir @@ -423,3 +423,17 @@ func.func @sparse_sort_2d1v(%arg0: index, %arg1: memref<10xi8>, %arg2: memref<20 sparse_tensor.sort %arg0, %arg1, %arg2 jointly %arg3 : memref<10xi8>, memref<20xi8> jointly memref<10xf64> return %arg1, %arg2, %arg3 : memref<10xi8>, memref<20xi8>, memref<10xf64> } + +// ----- + +// CHECK-LABEL: func @sparse_sort_stable( +// CHECK-SAME: %[[A:.*]]: index, +// CHECK-SAME: %[[B:.*]]: memref<10xi8>, +// CHECK-SAME: %[[C:.*]]: memref<20xi8>, +// CHECK-SAME: %[[D:.*]]: memref<10xf64>) +// CHECK: sparse_tensor.sort stable %[[A]], %[[B]], %[[C]] jointly %[[D]] : memref<10xi8>, memref<20xi8> jointly memref<10xf64> +// CHECK: return %[[B]], %[[C]], %[[D]] +func.func @sparse_sort_stable(%arg0: index, %arg1: memref<10xi8>, %arg2: memref<20xi8>, %arg3: memref<10xf64>) -> (memref<10xi8>, memref<20xi8>, memref<10xf64>) { + sparse_tensor.sort stable %arg0, %arg1, %arg2 jointly %arg3 : memref<10xi8>, memref<20xi8> jointly memref<10xf64> + return %arg1, %arg2, %arg3 : memref<10xi8>, memref<20xi8>, memref<10xf64> +}