[mlir][sparse] unifies sparse_tensor.sort_coo/sort into one operation. (#66722)

The use cases of the two operations are largely overlapped, let's
simplify it and only use one of them.
This commit is contained in:
Peiming Liu
2023-09-19 17:02:32 -07:00
committed by GitHub
parent 74338bfe0c
commit bfa3bc4378
14 changed files with 267 additions and 678 deletions

View File

@@ -762,81 +762,32 @@ def SparseTensor_OutOp : SparseTensor_Op<"out", []>,
// Sparse Tensor Sorting Operations.
//===----------------------------------------------------------------------===//
def SparseTensor_SortOp : SparseTensor_Op<"sort", [AttrSizedOperandSegments]>,
Arguments<(ins Index:$n,
Variadic<StridedMemRefRankOf<[AnyInteger, Index], [1]>>:$xs,
Variadic<StridedMemRefRankOf<[AnyType], [1]>>:$ys,
SparseTensorSortKindAttr:$algorithm)> {
string summary = "Sorts the arrays in xs and ys lexicographically on the "
"integral values found in the xs list";
string description = [{
Lexicographically sort the first `n` values in `xs` along with the values in
`ys`. Conceptually, the values being sorted are tuples produced by
`zip(zip(xs), zip(ys))`. In particular, values in `ys` needed to be sorted
along with values in `xs`, but values in `ys` don't affect the
lexicographical order. The order in which arrays appear in `xs` affects the
sorting result. The operator updates `xs` and `ys` in place with the result
of the sorting.
For example, assume x1=[4, 3], x2=[1, 2], y1=[10, 5], then the output of
"sort 2, x1, x2 jointly y1" are x1=[3, 4], x2=[2, 1], y1=[5, 10] while the
output of "sort 2, x2, x1, jointly y1" are x2=[1, 2], x1=[4, 3], y1=[10, 5].
Buffers in `xs` needs to have the same integral element type while buffers
in `ys` can have different numeric element types. All buffers in `xs` and
`ys` should have a dimension not less than `n`. The behavior of the operator
is undefined if this condition is not met. The operator requires at least
one buffer in `xs` while `ys` can be empty.
The enum attribute `algorithm` indicates the sorting algorithm used to
implement the operator: hybrid_quick_sort, insertion_sort_stable,
quick_sort, or heap_sort.
Note that this operation is "impure" in the sense that its behavior is
solely defined by side-effects and not SSA values.
Example:
```mlir
sparse_tensor.sort insertion_sort_stable %n, %x1, %x2 jointly y1, %y2
: memref<?xindex>, memref<?xindex> jointly memref<?xindex>, memref<?xf32>
```
```mlir
sparse_tensor.sort hybrid_quick_sort %n, %x1, %x2 jointly y1, %y2
{ alg=1 : index}
: memref<?xindex>, memref<?xindex> jointly memref<?xindex>, memref<?xf32>
```
}];
let assemblyFormat = "$algorithm $n `,` $xs (`jointly` $ys^)? attr-dict"
"`:` type($xs) (`jointly` type($ys)^)?";
let hasVerifier = 1;
}
def SparseTensor_SortCooOp : SparseTensor_Op<"sort_coo">,
Arguments<(ins Index:$n, StridedMemRefRankOf<[AnyInteger, Index], [1]>:$xy,
Variadic<StridedMemRefRankOf<[AnyType], [1]>>:$ys,
OptionalAttr<IndexAttr>:$nx, OptionalAttr<IndexAttr>:$ny,
AffineMapAttr:$perm_map, OptionalAttr<IndexAttr>:$ny,
SparseTensorSortKindAttr:$algorithm)> {
let summary = "Sorts the arrays in xs and ys lexicographically on the "
"integral values found in the xs list";
let description = [{
Sparse_tensor.sort_coo is similar to sparse_tensor.sort, except that all the
`xs` values and some `ys` values are put in the linear buffer `xy`. The
optional index attribute `nx` provides the number of `xs` values in `xy`.
When `nx` is not explicitly specified, its value is 1. The optional index
attribute `ny` provides the number of `ys` values in `xy`. When `ny` is not
explicitly specified, its value is 0. This instruction supports a more
efficient way to store the COO definition in sparse tensor type.
Sparse_tensor.sort_coo sort the `xs` values along with some `ys` values
that are put in a single linear buffer `xy`.
The affine map attribute `perm_map` specifies the permutation to be applied on
the `xs` before comparison, the rank of the permutation map
also specifies the number of `xs` values in `xy`.
The optional index attribute `ny` provides the number of `ys` values in `xy`.
When `ny` is not explicitly specified, its value is 0.
This instruction supports a more efficient way to store the COO definition
in sparse tensor type.
The buffer xy should have a dimension not less than n * (nx + ny) while the
The buffer xy should have a dimension not less than n * (rank(perm_map) + ny) while the
buffers in `ys` should have a dimension not less than `n`. The behavior of
the operator is undefined if this condition is not met.
Example:
```mlir
sparse_tensor.sort_coo insertion_sort_stable %n, %x { nx = 2 : index}
sparse_tensor.sort_coo insertion_sort_stable %n, %x { perm_map = affine_map<(i,j) -> (j,i)> }
: memref<?xindex>
```

View File

@@ -1353,35 +1353,15 @@ LogicalResult SelectOp::verify() {
return success();
}
LogicalResult SortOp::verify() {
if (getXs().empty())
return emitError("need at least one xs buffer.");
std::optional<int64_t> n = getConstantIntValue(getN());
Type xtp = getMemRefType(getXs().front()).getElementType();
auto checkTypes = [&](ValueRange operands,
bool checkEleType = true) -> LogicalResult {
for (Value opnd : operands) {
auto mtp = getMemRefType(opnd);
const DynSize sh = mtp.getShape()[0];
// We can't check the size of dynamic dimension at compile-time, but all
// xs and ys should have a dimension not less than n at runtime.
if (n && !ShapedType::isDynamic(sh) && sh < n.value())
return emitError(llvm::formatv("xs and ys need to have a dimension >= n"
": {0} < {1}",
sh, n.value()));
if (checkEleType && xtp != mtp.getElementType())
return emitError("mismatch xs element types");
}
return success();
};
RETURN_FAILURE_IF_FAILED(checkTypes(getXs()))
return n ? checkTypes(getYs(), false) : success();
}
LogicalResult SortCooOp::verify() {
AffineMap xPerm = getPermMap();
uint64_t nx = xPerm.getNumDims();
if (nx < 1)
emitError(llvm::formatv("Expected rank(perm_map) > 1, got {0}", nx));
if (!xPerm.isPermutation())
emitError(llvm::formatv("Expected a permutation map, got {0}", xPerm));
std::optional<int64_t> cn = getConstantIntValue(getN());
// We can't check the size of the buffers when n or buffer dimensions aren't
// compile-time constants.
@@ -1389,12 +1369,6 @@ LogicalResult SortCooOp::verify() {
return success();
uint64_t n = cn.value();
uint64_t nx = 1;
if (auto nxAttr = getNxAttr()) {
nx = nxAttr.getInt();
if (nx < 1)
emitError(llvm::formatv("Expected nx > 1, got {0}", nx));
}
uint64_t ny = 0;
if (auto nyAttr = getNyAttr()) {
ny = nyAttr.getInt();
@@ -1409,7 +1383,8 @@ LogicalResult SortCooOp::verify() {
emitError(llvm::formatv("{0} got {1} < {2}", message, sh, minSize));
};
checkDim(getXy(), n * (nx + ny), "Expected dimension(xy) >= n * (nx + ny)");
checkDim(getXy(), n * (nx + ny),
"Expected dimension(xy) >= n * (rank(perm_map) + ny)");
for (Value opnd : getYs()) {
checkDim(opnd, n, "Expected dimension(y) >= n");

View File

@@ -45,46 +45,43 @@ static constexpr const char kShiftDownFuncNamePrefix[] = "_sparse_shift_down_";
static constexpr const char kHeapSortFuncNamePrefix[] = "_sparse_heap_sort_";
static constexpr const char kQuickSortFuncNamePrefix[] = "_sparse_qsort_";
using FuncGeneratorType = function_ref<void(
OpBuilder &, ModuleOp, func::FuncOp, uint64_t, uint64_t, bool, uint32_t)>;
using FuncGeneratorType = function_ref<void(OpBuilder &, ModuleOp, func::FuncOp,
AffineMap, uint64_t, uint32_t)>;
/// Constructs a function name with this format to facilitate quick sort:
/// <namePrefix><nx>_<x type>_<y0 type>..._<yn type> for sort
/// <namePrefix><nx>_<x type>_coo_<ny>_<y0 type>..._<yn type> for sort_coo
/// <namePrefix><xPerm>_<x type>_<y0 type>..._<yn type> for sort
/// <namePrefix><xPerm>_<x type>_coo_<ny>_<y0 type>..._<yn type> for sort_coo
static void getMangledSortHelperFuncName(llvm::raw_svector_ostream &nameOstream,
StringRef namePrefix, uint64_t nx,
uint64_t ny, bool isCoo,
ValueRange operands) {
nameOstream << namePrefix << nx << "_"
<< getMemRefType(operands[xStartIdx]).getElementType();
StringRef namePrefix, AffineMap xPerm,
uint64_t ny, ValueRange operands) {
nameOstream << namePrefix;
for (auto res : xPerm.getResults())
nameOstream << res.cast<AffineDimExpr>().getPosition() << "_";
if (isCoo)
nameOstream << "_coo_" << ny;
nameOstream << getMemRefType(operands[xStartIdx]).getElementType();
nameOstream << "_coo_" << ny;
uint64_t yBufferOffset = isCoo ? 1 : nx;
constexpr uint64_t yBufferOffset = 1;
for (Value v : operands.drop_front(xStartIdx + yBufferOffset))
nameOstream << "_" << getMemRefType(v).getElementType();
}
/// Looks up a function that is appropriate for the given operands being
/// sorted, and creates such a function if it doesn't exist yet. The
/// parameters `nx` and `ny` tell the number of x and y values provided
/// by the buffer in xStartIdx, and `isCoo` indicates whether the instruction
/// being processed is a sparse_tensor.sort or sparse_tensor.sort_coo.
/// parameters `xPerm` and `ny` tell the number of x and y values provided
/// by the buffer in xStartIdx.
//
// All sorting function generators take (lo, hi, xs, ys) in `operands` as
// parameters for the sorting functions. Other parameters, such as the recursive
// call depth, are appended to the end of the parameter list as
// "trailing parameters".
static FlatSymbolRefAttr
getMangledSortHelperFunc(OpBuilder &builder, func::FuncOp insertPoint,
TypeRange resultTypes, StringRef namePrefix,
uint64_t nx, uint64_t ny, bool isCoo,
ValueRange operands, FuncGeneratorType createFunc,
uint32_t nTrailingP = 0) {
static FlatSymbolRefAttr getMangledSortHelperFunc(
OpBuilder &builder, func::FuncOp insertPoint, TypeRange resultTypes,
StringRef namePrefix, AffineMap xPerm, uint64_t ny, ValueRange operands,
FuncGeneratorType createFunc, uint32_t nTrailingP = 0) {
SmallString<32> nameBuffer;
llvm::raw_svector_ostream nameOstream(nameBuffer);
getMangledSortHelperFuncName(nameOstream, namePrefix, nx, ny, isCoo,
getMangledSortHelperFuncName(nameOstream, namePrefix, xPerm, ny,
operands.drop_back(nTrailingP));
ModuleOp module = insertPoint->getParentOfType<ModuleOp>();
@@ -101,7 +98,7 @@ getMangledSortHelperFunc(OpBuilder &builder, func::FuncOp insertPoint,
loc, nameOstream.str(),
FunctionType::get(context, operands.getTypes(), resultTypes));
func.setPrivate();
createFunc(builder, module, func, nx, ny, isCoo, nTrailingP);
createFunc(builder, module, func, xPerm, ny, nTrailingP);
}
return result;
@@ -110,27 +107,19 @@ getMangledSortHelperFunc(OpBuilder &builder, func::FuncOp insertPoint,
/// Creates a code block to process each pair of (xs[i], xs[j]) for sorting.
/// The code to process the value pairs is generated by `bodyBuilder`.
static void forEachIJPairInXs(
OpBuilder &builder, Location loc, ValueRange args, uint64_t nx, uint64_t ny,
bool isCoo, function_ref<void(uint64_t, Value, Value, Value)> bodyBuilder) {
Value iOffset, jOffset;
if (isCoo) {
Value cstep = constantIndex(builder, loc, nx + ny);
iOffset = builder.create<arith::MulIOp>(loc, args[0], cstep);
jOffset = builder.create<arith::MulIOp>(loc, args[1], cstep);
}
for (uint64_t k = 0; k < nx; k++) {
scf::IfOp ifOp;
Value i, j, buffer;
if (isCoo) {
Value ck = constantIndex(builder, loc, k);
i = builder.create<arith::AddIOp>(loc, ck, iOffset);
j = builder.create<arith::AddIOp>(loc, ck, jOffset);
buffer = args[xStartIdx];
} else {
i = args[0];
j = args[1];
buffer = args[xStartIdx + k];
}
OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm,
uint64_t ny,
function_ref<void(uint64_t, Value, Value, Value)> bodyBuilder) {
Value cstep = constantIndex(builder, loc, xPerm.getNumResults() + ny);
Value iOffset = builder.create<arith::MulIOp>(loc, args[0], cstep);
Value jOffset = builder.create<arith::MulIOp>(loc, args[1], cstep);
for (unsigned k = 0, e = xPerm.getNumResults(); k < e; k++) {
unsigned actualK = xPerm.getResult(k).cast<AffineDimExpr>().getPosition();
Value ak = constantIndex(builder, loc, actualK);
Value i = builder.create<arith::AddIOp>(loc, ak, iOffset);
Value j = builder.create<arith::AddIOp>(loc, ak, jOffset);
Value buffer = args[xStartIdx];
bodyBuilder(k, i, j, buffer);
}
}
@@ -138,21 +127,28 @@ static void forEachIJPairInXs(
/// Creates a code block to process each pair of (xys[i], xys[j]) for sorting.
/// The code to process the value pairs is generated by `bodyBuilder`.
static void forEachIJPairInAllBuffers(
OpBuilder &builder, Location loc, ValueRange args, uint64_t nx, uint64_t ny,
bool isCoo, function_ref<void(uint64_t, Value, Value, Value)> bodyBuilder) {
OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm,
uint64_t ny,
function_ref<void(uint64_t, Value, Value, Value)> bodyBuilder) {
// Create code for the first (nx + ny) buffers. When isCoo==true, these
// logical buffers are all from the xy buffer of the sort_coo operator.
forEachIJPairInXs(builder, loc, args, nx + ny, 0, isCoo, bodyBuilder);
// Create code for the first (xPerm + ny) buffers.
SmallVector<AffineExpr> exps(xPerm.getResults().begin(),
xPerm.getResults().end());
for (unsigned y = 0; y < ny; y++) {
exps.push_back(builder.getAffineDimExpr(y + xPerm.getNumResults()));
}
AffineMap xyPerm = AffineMap::get(exps.size(), 0, exps, builder.getContext());
assert(xyPerm.isPermutation());
uint64_t numHandledBuffers = isCoo ? 1 : nx + ny;
forEachIJPairInXs(builder, loc, args, xyPerm, 0, bodyBuilder);
constexpr uint64_t numHandledBuffers = 1;
// Create code for the remaining buffers.
Value i = args[0];
Value j = args[1];
for (const auto &arg :
llvm::enumerate(args.drop_front(xStartIdx + numHandledBuffers))) {
bodyBuilder(arg.index() + nx + ny, i, j, arg.value());
bodyBuilder(arg.index() + xPerm.getNumResults() + ny, i, j, arg.value());
}
}
@@ -168,7 +164,7 @@ static void forEachIJPairInAllBuffers(
// ...
// swap(yn[i], yn[j]);
static void createSwap(OpBuilder &builder, Location loc, ValueRange args,
uint64_t nx, uint64_t ny, bool isCoo) {
AffineMap xPerm, uint64_t ny) {
auto swapOnePair = [&](uint64_t unused, Value i, Value j, Value buffer) {
Value vi = builder.create<memref::LoadOp>(loc, buffer, i);
Value vj = builder.create<memref::LoadOp>(loc, buffer, j);
@@ -176,20 +172,20 @@ static void createSwap(OpBuilder &builder, Location loc, ValueRange args,
builder.create<memref::StoreOp>(loc, vi, buffer, j);
};
forEachIJPairInAllBuffers(builder, loc, args, nx, ny, isCoo, swapOnePair);
forEachIJPairInAllBuffers(builder, loc, args, xPerm, ny, swapOnePair);
}
/// Creates code to compare all the (xs[i], xs[j]) pairs. The method to compare
/// each pair is create via `compareBuilder`.
static Value createInlinedCompareImplementation(
OpBuilder &builder, Location loc, ValueRange args, uint64_t nx, uint64_t ny,
bool isCoo,
OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm,
uint64_t ny,
function_ref<Value(OpBuilder &, Location, Value, Value, Value, bool, bool)>
compareBuilder) {
Value result;
auto bodyBuilder = [&](uint64_t k, Value i, Value j, Value buffer) {
bool isFirstDim = (k == 0);
bool isLastDim = (k == nx - 1);
bool isLastDim = (k == xPerm.getNumResults() - 1);
Value val =
compareBuilder(builder, loc, i, j, buffer, isFirstDim, isLastDim);
if (isFirstDim) {
@@ -202,7 +198,7 @@ static Value createInlinedCompareImplementation(
}
};
forEachIJPairInXs(builder, loc, args, nx, ny, isCoo, bodyBuilder);
forEachIJPairInXs(builder, loc, args, xPerm, ny, bodyBuilder);
builder.setInsertionPointAfterValue(result);
return result;
@@ -252,12 +248,12 @@ static Value createEqCompare(OpBuilder &builder, Location loc, Value i, Value j,
// else if (x2[2] != x2[j]))
// and so on ...
static Value createInlinedEqCompare(OpBuilder &builder, Location loc,
ValueRange args, uint64_t nx, uint64_t ny,
bool isCoo, uint32_t nTrailingP = 0) {
ValueRange args, AffineMap xPerm,
uint64_t ny, uint32_t nTrailingP = 0) {
// Compare functions don't use trailing parameters.
(void)nTrailingP;
assert(nTrailingP == 0);
return createInlinedCompareImplementation(builder, loc, args, nx, ny, isCoo,
return createInlinedCompareImplementation(builder, loc, args, xPerm, ny,
createEqCompare);
}
@@ -306,12 +302,12 @@ static Value createLessThanCompare(OpBuilder &builder, Location loc, Value i,
// else
// and so on ...
static Value createInlinedLessThan(OpBuilder &builder, Location loc,
ValueRange args, uint64_t nx, uint64_t ny,
bool isCoo, uint32_t nTrailingP = 0) {
ValueRange args, AffineMap xPerm,
uint64_t ny, uint32_t nTrailingP = 0) {
// Compare functions don't use trailing parameters.
(void)nTrailingP;
assert(nTrailingP == 0);
return createInlinedCompareImplementation(builder, loc, args, nx, ny, isCoo,
return createInlinedCompareImplementation(builder, loc, args, xPerm, ny,
createLessThanCompare);
}
@@ -329,8 +325,8 @@ static Value createInlinedLessThan(OpBuilder &builder, Location loc,
// return lo;
//
static void createBinarySearchFunc(OpBuilder &builder, ModuleOp module,
func::FuncOp func, uint64_t nx, uint64_t ny,
bool isCoo, uint32_t nTrailingP = 0) {
func::FuncOp func, AffineMap xPerm,
uint64_t ny, uint32_t nTrailingP = 0) {
// Binary search doesn't use trailing parameters.
(void)nTrailingP;
assert(nTrailingP == 0);
@@ -368,11 +364,10 @@ static void createBinarySearchFunc(OpBuilder &builder, ModuleOp module,
// Compare xs[p] < xs[mid].
SmallVector<Value> compareOperands{p, mid};
uint64_t numXBuffers = isCoo ? 1 : nx;
constexpr uint64_t numXBuffers = 1;
compareOperands.append(args.begin() + xStartIdx,
args.begin() + xStartIdx + numXBuffers);
Value cond2 =
createInlinedLessThan(builder, loc, compareOperands, nx, ny, isCoo);
Value cond2 = createInlinedLessThan(builder, loc, compareOperands, xPerm, ny);
// Update lo and hi for the WhileOp as follows:
// if (xs[p] < xs[mid]))
// hi = mid;
@@ -392,10 +387,11 @@ static void createBinarySearchFunc(OpBuilder &builder, ModuleOp module,
/// while (xs[i] > xs[p]) i += step (step < 0)
/// The routine returns i as well as a boolean value to indicate whether
/// xs[i] == xs[p].
static std::pair<Value, Value>
createScanLoop(OpBuilder &builder, ModuleOp module, func::FuncOp func,
ValueRange xs, Value i, Value p, uint64_t nx, uint64_t ny,
bool isCoo, int step) {
static std::pair<Value, Value> createScanLoop(OpBuilder &builder,
ModuleOp module,
func::FuncOp func, ValueRange xs,
Value i, Value p, AffineMap xPerm,
uint64_t ny, int step) {
Location loc = func.getLoc();
scf::WhileOp whileOp =
builder.create<scf::WhileOp>(loc, TypeRange{i.getType()}, ValueRange{i});
@@ -413,8 +409,7 @@ createScanLoop(OpBuilder &builder, ModuleOp module, func::FuncOp func,
compareOperands.push_back(before->getArgument(0));
}
compareOperands.append(xs.begin(), xs.end());
Value cond =
createInlinedLessThan(builder, loc, compareOperands, nx, ny, isCoo);
Value cond = createInlinedLessThan(builder, loc, compareOperands, xPerm, ny);
builder.create<scf::ConditionOp>(loc, cond, before->getArguments());
Block *after =
@@ -429,7 +424,7 @@ createScanLoop(OpBuilder &builder, ModuleOp module, func::FuncOp func,
compareOperands[0] = i;
compareOperands[1] = p;
Value compareEq =
createInlinedEqCompare(builder, loc, compareOperands, nx, ny, isCoo);
createInlinedEqCompare(builder, loc, compareOperands, xPerm, ny);
return std::make_pair(whileOp.getResult(0), compareEq);
}
@@ -438,67 +433,63 @@ createScanLoop(OpBuilder &builder, ModuleOp module, func::FuncOp func,
/// if compareFunc(data[b], data[a]) returns true. The new insertion point is
/// right after the swap instructions.
static scf::IfOp createCompareThenSwap(OpBuilder &builder, Location loc,
uint64_t nx, uint64_t ny, bool isCoo,
AffineMap xPerm, uint64_t ny,
SmallVectorImpl<Value> &swapOperands,
SmallVectorImpl<Value> &compareOperands,
Value a, Value b) {
// Compare(data[b], data[a]).
compareOperands[0] = b;
compareOperands[1] = a;
Value cond =
createInlinedLessThan(builder, loc, compareOperands, nx, ny, isCoo);
Value cond = createInlinedLessThan(builder, loc, compareOperands, xPerm, ny);
scf::IfOp ifOp = builder.create<scf::IfOp>(loc, cond, /*else=*/false);
builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
swapOperands[0] = b;
swapOperands[1] = a;
createSwap(builder, loc, swapOperands, nx, ny, isCoo);
createSwap(builder, loc, swapOperands, xPerm, ny);
return ifOp;
}
/// Creates code to insert the 3rd element to a list of two sorted elements.
static void createInsert3rd(OpBuilder &builder, Location loc, uint64_t nx,
uint64_t ny, bool isCoo,
SmallVectorImpl<Value> &swapOperands,
static void createInsert3rd(OpBuilder &builder, Location loc, AffineMap xPerm,
uint64_t ny, SmallVectorImpl<Value> &swapOperands,
SmallVectorImpl<Value> &compareOperands, Value v0,
Value v1, Value v2) {
scf::IfOp ifOp = createCompareThenSwap(builder, loc, nx, ny, isCoo,
swapOperands, compareOperands, v1, v2);
createCompareThenSwap(builder, loc, nx, ny, isCoo, swapOperands,
compareOperands, v0, v1);
scf::IfOp ifOp = createCompareThenSwap(builder, loc, xPerm, ny, swapOperands,
compareOperands, v1, v2);
createCompareThenSwap(builder, loc, xPerm, ny, swapOperands, compareOperands,
v0, v1);
builder.setInsertionPointAfter(ifOp);
}
/// Creates code to sort 3 elements.
static void createSort3(OpBuilder &builder, Location loc, uint64_t nx,
uint64_t ny, bool isCoo,
SmallVectorImpl<Value> &swapOperands,
static void createSort3(OpBuilder &builder, Location loc, AffineMap xPerm,
uint64_t ny, SmallVectorImpl<Value> &swapOperands,
SmallVectorImpl<Value> &compareOperands, Value v0,
Value v1, Value v2) {
// Sort the first 2 elements.
scf::IfOp ifOp1 = createCompareThenSwap(
builder, loc, nx, ny, isCoo, swapOperands, compareOperands, v0, v1);
scf::IfOp ifOp1 = createCompareThenSwap(builder, loc, xPerm, ny, swapOperands,
compareOperands, v0, v1);
builder.setInsertionPointAfter(ifOp1);
// Insert the 3th element.
createInsert3rd(builder, loc, nx, ny, isCoo, swapOperands, compareOperands,
v0, v1, v2);
createInsert3rd(builder, loc, xPerm, ny, swapOperands, compareOperands, v0,
v1, v2);
}
/// Creates code to sort 5 elements.
static void createSort5(OpBuilder &builder, Location loc, uint64_t nx,
uint64_t ny, bool isCoo,
SmallVectorImpl<Value> &swapOperands,
static void createSort5(OpBuilder &builder, Location loc, AffineMap xPerm,
uint64_t ny, SmallVectorImpl<Value> &swapOperands,
SmallVectorImpl<Value> &compareOperands, Value v0,
Value v1, Value v2, Value v3, Value v4) {
// Sort the first 3 elements.
createSort3(builder, loc, nx, ny, isCoo, swapOperands, compareOperands, v0,
v1, v2);
createSort3(builder, loc, xPerm, ny, swapOperands, compareOperands, v0, v1,
v2);
auto insert4th = [&]() {
scf::IfOp ifOp = createCompareThenSwap(
builder, loc, nx, ny, isCoo, swapOperands, compareOperands, v2, v3);
createInsert3rd(builder, loc, nx, ny, isCoo, swapOperands, compareOperands,
v0, v1, v2);
builder, loc, xPerm, ny, swapOperands, compareOperands, v2, v3);
createInsert3rd(builder, loc, xPerm, ny, swapOperands, compareOperands, v0,
v1, v2);
builder.setInsertionPointAfter(ifOp);
};
@@ -506,8 +497,8 @@ static void createSort5(OpBuilder &builder, Location loc, uint64_t nx,
insert4th();
// Insert the 5th element.
scf::IfOp ifOp = createCompareThenSwap(builder, loc, nx, ny, isCoo,
swapOperands, compareOperands, v3, v4);
scf::IfOp ifOp = createCompareThenSwap(builder, loc, xPerm, ny, swapOperands,
compareOperands, v3, v4);
insert4th();
builder.setInsertionPointAfter(ifOp);
}
@@ -517,11 +508,10 @@ static void createSort5(OpBuilder &builder, Location loc, uint64_t nx,
/// the number of values in range [lo, hi) is more than a threshold, we also
/// include the middle of [lo, mi) and [mi, hi) and sort a total of five values.
static void createChoosePivot(OpBuilder &builder, ModuleOp module,
func::FuncOp func, uint64_t nx, uint64_t ny,
bool isCoo, Value lo, Value hi, Value mi,
ValueRange args) {
func::FuncOp func, AffineMap xPerm, uint64_t ny,
Value lo, Value hi, Value mi, ValueRange args) {
SmallVector<Value> compareOperands{mi, lo};
uint64_t numXBuffers = isCoo ? 1 : nx;
constexpr uint64_t numXBuffers = 1;
compareOperands.append(args.begin() + xStartIdx,
args.begin() + xStartIdx + numXBuffers);
SmallVector<Value> swapOperands{mi, lo};
@@ -537,8 +527,8 @@ static void createChoosePivot(OpBuilder &builder, ModuleOp module,
// When len < 1000, choose pivot from median of 3 values.
builder.setInsertionPointToStart(&lenIf.getThenRegion().front());
createSort3(builder, loc, nx, ny, isCoo, swapOperands, compareOperands, lo,
mi, hi);
createSort3(builder, loc, xPerm, ny, swapOperands, compareOperands, lo, mi,
hi);
// When len >= 1000, choose pivot from median of 5 values.
builder.setInsertionPointToStart(&lenIf.getElseRegion().front());
@@ -549,8 +539,8 @@ static void createChoosePivot(OpBuilder &builder, ModuleOp module,
Value b = builder.create<arith::AddIOp>(loc, mi, hiP1);
// Value b is the middle between [mi, hi].
b = builder.create<arith::ShRUIOp>(loc, b, c1);
createSort5(builder, loc, nx, ny, isCoo, swapOperands, compareOperands, lo, a,
mi, b, hi);
createSort5(builder, loc, xPerm, ny, swapOperands, compareOperands, lo, a, mi,
b, hi);
builder.setInsertionPointAfter(lenIf);
}
@@ -586,8 +576,8 @@ static void createChoosePivot(OpBuilder &builder, ModuleOp module,
// }
// }
static void createPartitionFunc(OpBuilder &builder, ModuleOp module,
func::FuncOp func, uint64_t nx, uint64_t ny,
bool isCoo, uint32_t nTrailingP = 0) {
func::FuncOp func, AffineMap xPerm, uint64_t ny,
uint32_t nTrailingP = 0) {
// Quick sort partition doesn't use trailing parameters.
(void)nTrailingP;
assert(nTrailingP == 0);
@@ -606,7 +596,7 @@ static void createPartitionFunc(OpBuilder &builder, ModuleOp module,
Value i = lo;
Value j = builder.create<arith::SubIOp>(loc, hi, c1);
createChoosePivot(builder, module, func, nx, ny, isCoo, i, j, p, args);
createChoosePivot(builder, module, func, xPerm, ny, i, j, p, args);
Value trueVal = constantI1(builder, loc, true); // The value for while (true)
SmallVector<Value, 4> operands{i, j, p, trueVal}; // Exactly four values.
SmallVector<Type, 4> types{i.getType(), j.getType(), p.getType(),
@@ -628,14 +618,14 @@ static void createPartitionFunc(OpBuilder &builder, ModuleOp module,
j = after->getArgument(1);
p = after->getArgument(2);
uint64_t numXBuffers = isCoo ? 1 : nx;
constexpr uint64_t numXBuffers = 1;
auto [iresult, iCompareEq] =
createScanLoop(builder, module, func, args.slice(xStartIdx, numXBuffers),
i, p, nx, ny, isCoo, 1);
i, p, xPerm, ny, 1);
i = iresult;
auto [jresult, jCompareEq] =
createScanLoop(builder, module, func, args.slice(xStartIdx, numXBuffers),
j, p, nx, ny, isCoo, -1);
j, p, xPerm, ny, -1);
j = jresult;
// If i < j:
@@ -645,7 +635,7 @@ static void createPartitionFunc(OpBuilder &builder, ModuleOp module,
builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
SmallVector<Value> swapOperands{i, j};
swapOperands.append(args.begin() + xStartIdx, args.end());
createSwap(builder, loc, swapOperands, nx, ny, isCoo);
createSwap(builder, loc, swapOperands, xPerm, ny);
// If the pivot is moved, update p with the new pivot.
Value icond =
builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, i, p);
@@ -737,8 +727,8 @@ static Value createSubTwoDividedByTwo(OpBuilder &builder, Location loc,
// }
//
static void createShiftDownFunc(OpBuilder &builder, ModuleOp module,
func::FuncOp func, uint64_t nx, uint64_t ny,
bool isCoo, uint32_t nTrailingP) {
func::FuncOp func, AffineMap xPerm, uint64_t ny,
uint32_t nTrailingP) {
// The value n is passed in as a trailing parameter.
assert(nTrailingP == 1);
OpBuilder::InsertionGuard insertionGuard(builder);
@@ -768,7 +758,7 @@ static void createShiftDownFunc(OpBuilder &builder, ModuleOp module,
builder.setInsertionPointToStart(&ifNc.getThenRegion().front());
Value c1 = constantIndex(builder, loc, 1);
SmallVector<Value> compareOperands{start, start};
uint64_t numXBuffers = isCoo ? 1 : nx;
constexpr uint64_t numXBuffers = 1;
compareOperands.append(args.begin() + xStartIdx,
args.begin() + xStartIdx + numXBuffers);
@@ -794,7 +784,7 @@ static void createShiftDownFunc(OpBuilder &builder, ModuleOp module,
compareOperands[0] = lChildIdx;
compareOperands[1] = rChildIdx;
Value cond2 =
createInlinedLessThan(builder, loc, compareOperands, nx, ny, isCoo);
createInlinedLessThan(builder, loc, compareOperands, xPerm, ny);
scf::IfOp if2 =
builder.create<scf::IfOp>(loc, ifTypes, cond2, /*else=*/true);
builder.setInsertionPointToStart(&if2.getThenRegion().front());
@@ -825,8 +815,7 @@ static void createShiftDownFunc(OpBuilder &builder, ModuleOp module,
childIdx = before->getArgument(2);
compareOperands[0] = start;
compareOperands[1] = childIdx;
Value cond =
createInlinedLessThan(builder, loc, compareOperands, nx, ny, isCoo);
Value cond = createInlinedLessThan(builder, loc, compareOperands, xPerm, ny);
builder.create<scf::ConditionOp>(loc, cond, before->getArguments());
// The after-region of the WhileOp.
@@ -836,7 +825,7 @@ static void createShiftDownFunc(OpBuilder &builder, ModuleOp module,
childIdx = after->getArgument(2);
SmallVector<Value> swapOperands{start, childIdx};
swapOperands.append(args.begin() + xStartIdx, args.end());
createSwap(builder, loc, swapOperands, nx, ny, isCoo);
createSwap(builder, loc, swapOperands, xPerm, ny);
start = childIdx;
Value cond2 =
builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge, t, child);
@@ -869,8 +858,8 @@ static void createShiftDownFunc(OpBuilder &builder, ModuleOp module,
// shiftdown(lo, lo, l-1)
// }
static void createHeapSortFunc(OpBuilder &builder, ModuleOp module,
func::FuncOp func, uint64_t nx, uint64_t ny,
bool isCoo, uint32_t nTrailingP) {
func::FuncOp func, AffineMap xPerm, uint64_t ny,
uint32_t nTrailingP) {
// Heap sort function doesn't have trailing parameters.
(void)nTrailingP;
assert(nTrailingP == 0);
@@ -897,7 +886,7 @@ static void createHeapSortFunc(OpBuilder &builder, ModuleOp module,
shiftDownOperands.append(args.begin() + xStartIdx, args.end());
shiftDownOperands.push_back(n);
FlatSymbolRefAttr shiftDownFunc = getMangledSortHelperFunc(
builder, func, TypeRange(), kShiftDownFuncNamePrefix, nx, ny, isCoo,
builder, func, TypeRange(), kShiftDownFuncNamePrefix, xPerm, ny,
shiftDownOperands, createShiftDownFunc, /*nTrailingP=*/1);
builder.create<func::CallOp>(loc, shiftDownFunc, TypeRange(),
shiftDownOperands);
@@ -912,7 +901,7 @@ static void createHeapSortFunc(OpBuilder &builder, ModuleOp module,
loplm1 = builder.create<arith::SubIOp>(loc, loplm1, c1);
SmallVector<Value> swapOperands{lo, loplm1};
swapOperands.append(args.begin() + xStartIdx, args.end());
createSwap(builder, loc, swapOperands, nx, ny, isCoo);
createSwap(builder, loc, swapOperands, xPerm, ny);
shiftDownOperands[1] = lo;
shiftDownOperands[shiftDownOperands.size() - 1] =
builder.create<arith::SubIOp>(loc, l, c1);
@@ -928,7 +917,7 @@ static void createHeapSortFunc(OpBuilder &builder, ModuleOp module,
/// the bigger partition to be processed by the enclosed while-loop.
static std::pair<Value, Value>
createQuickSort(OpBuilder &builder, ModuleOp module, func::FuncOp func,
ValueRange args, uint64_t nx, uint64_t ny, bool isCoo,
ValueRange args, AffineMap xPerm, uint64_t ny,
uint32_t nTrailingP) {
MLIRContext *context = module.getContext();
Location loc = func.getLoc();
@@ -937,8 +926,8 @@ createQuickSort(OpBuilder &builder, ModuleOp module, func::FuncOp func,
SmallVector<Type, 2> types(2, lo.getType()); // Only two types.
FlatSymbolRefAttr partitionFunc = getMangledSortHelperFunc(
builder, func, {IndexType::get(context)}, kPartitionFuncNamePrefix, nx,
ny, isCoo, args.drop_back(nTrailingP), createPartitionFunc);
builder, func, {IndexType::get(context)}, kPartitionFuncNamePrefix, xPerm,
ny, args.drop_back(nTrailingP), createPartitionFunc);
Value p = builder
.create<func::CallOp>(loc, partitionFunc,
TypeRange{IndexType::get(context)},
@@ -1008,8 +997,8 @@ createQuickSort(OpBuilder &builder, ModuleOp module, func::FuncOp func,
// }
// }
static void createSortStableFunc(OpBuilder &builder, ModuleOp module,
func::FuncOp func, uint64_t nx, uint64_t ny,
bool isCoo, uint32_t nTrailingP) {
func::FuncOp func, AffineMap xPerm,
uint64_t ny, uint32_t nTrailingP) {
// Stable sort function doesn't use trailing parameters.
(void)nTrailingP;
assert(nTrailingP == 0);
@@ -1034,8 +1023,8 @@ static void createSortStableFunc(OpBuilder &builder, ModuleOp module,
SmallVector<Value> operands{lo, i};
operands.append(args.begin() + xStartIdx, args.end());
FlatSymbolRefAttr searchFunc = getMangledSortHelperFunc(
builder, func, {IndexType::get(context)}, kBinarySearchFuncNamePrefix, nx,
ny, isCoo, operands, createBinarySearchFunc);
builder, func, {IndexType::get(context)}, kBinarySearchFuncNamePrefix,
xPerm, ny, operands, createBinarySearchFunc);
Value p = builder
.create<func::CallOp>(loc, searchFunc, TypeRange{c1.getType()},
operands)
@@ -1045,7 +1034,7 @@ static void createSortStableFunc(OpBuilder &builder, ModuleOp module,
operands[0] = operands[1] = i;
SmallVector<Value> d;
forEachIJPairInAllBuffers(
builder, loc, operands, nx, ny, isCoo,
builder, loc, operands, xPerm, ny,
[&](uint64_t unused, Value i, Value unused2, Value buffer) {
d.push_back(builder.create<memref::LoadOp>(loc, buffer, i));
});
@@ -1061,7 +1050,7 @@ static void createSortStableFunc(OpBuilder &builder, ModuleOp module,
operands[1] = imj;
operands[0] = builder.create<arith::SubIOp>(loc, imj, c1);
forEachIJPairInAllBuffers(
builder, loc, operands, nx, ny, isCoo,
builder, loc, operands, xPerm, ny,
[&](uint64_t unused, Value imjm1, Value imj, Value buffer) {
Value t = builder.create<memref::LoadOp>(loc, buffer, imjm1);
builder.create<memref::StoreOp>(loc, t, buffer, imj);
@@ -1071,7 +1060,7 @@ static void createSortStableFunc(OpBuilder &builder, ModuleOp module,
builder.setInsertionPointAfter(forOpJ);
operands[0] = operands[1] = p;
forEachIJPairInAllBuffers(
builder, loc, operands, nx, ny, isCoo,
builder, loc, operands, xPerm, ny,
[&](uint64_t k, Value p, Value usused, Value buffer) {
builder.create<memref::StoreOp>(loc, d[k], buffer, p);
});
@@ -1123,8 +1112,8 @@ static void createSortStableFunc(OpBuilder &builder, ModuleOp module,
// }
//
static void createQuickSortFunc(OpBuilder &builder, ModuleOp module,
func::FuncOp func, uint64_t nx, uint64_t ny,
bool isCoo, uint32_t nTrailingP) {
func::FuncOp func, AffineMap xPerm, uint64_t ny,
uint32_t nTrailingP) {
assert(nTrailingP == 1 || nTrailingP == 0);
bool isHybrid = (nTrailingP == 1);
OpBuilder::InsertionGuard insertionGuard(builder);
@@ -1173,7 +1162,7 @@ static void createQuickSortFunc(OpBuilder &builder, ModuleOp module,
// When len <= limit.
builder.setInsertionPointToStart(&lenIf.getThenRegion().front());
FlatSymbolRefAttr insertionSortFunc = getMangledSortHelperFunc(
builder, func, TypeRange(), kSortStableFuncNamePrefix, nx, ny, isCoo,
builder, func, TypeRange(), kSortStableFuncNamePrefix, xPerm, ny,
ValueRange(args).drop_back(nTrailingP), createSortStableFunc);
builder.create<func::CallOp>(loc, insertionSortFunc, TypeRange(),
ValueRange(args).drop_back(nTrailingP));
@@ -1193,7 +1182,7 @@ static void createQuickSortFunc(OpBuilder &builder, ModuleOp module,
// When depth exceeds limit.
builder.setInsertionPointToStart(&depthIf.getThenRegion().front());
FlatSymbolRefAttr heapSortFunc = getMangledSortHelperFunc(
builder, func, TypeRange(), kHeapSortFuncNamePrefix, nx, ny, isCoo,
builder, func, TypeRange(), kHeapSortFuncNamePrefix, xPerm, ny,
ValueRange(args).drop_back(nTrailingP), createHeapSortFunc);
builder.create<func::CallOp>(loc, heapSortFunc, TypeRange(),
ValueRange(args).drop_back(nTrailingP));
@@ -1203,7 +1192,7 @@ static void createQuickSortFunc(OpBuilder &builder, ModuleOp module,
builder.setInsertionPointToStart(&depthIf.getElseRegion().front());
args.back() = depthLimit;
std::tie(lo, hi) =
createQuickSort(builder, module, func, args, nx, ny, isCoo, nTrailingP);
createQuickSort(builder, module, func, args, xPerm, ny, nTrailingP);
builder.create<scf::YieldOp>(loc, ValueRange{lo, hi});
builder.setInsertionPointAfter(depthIf);
@@ -1216,7 +1205,7 @@ static void createQuickSortFunc(OpBuilder &builder, ModuleOp module,
hi = lenIf.getResult(1);
} else {
std::tie(lo, hi) =
createQuickSort(builder, module, func, args, nx, ny, isCoo, nTrailingP);
createQuickSort(builder, module, func, args, xPerm, ny, nTrailingP);
}
// New [lo, hi) for the next while-loop iteration.
@@ -1229,9 +1218,8 @@ static void createQuickSortFunc(OpBuilder &builder, ModuleOp module,
/// Implements the rewriting for operator sort and sort_coo.
template <typename OpTy>
LogicalResult matchAndRewriteSortOp(OpTy op, ValueRange xys, uint64_t nx,
uint64_t ny, bool isCoo,
PatternRewriter &rewriter) {
LogicalResult matchAndRewriteSortOp(OpTy op, ValueRange xys, AffineMap xPerm,
uint64_t ny, PatternRewriter &rewriter) {
Location loc = op.getLoc();
SmallVector<Value> operands{constantIndex(rewriter, loc, 0), op.getN()};
@@ -1285,8 +1273,8 @@ LogicalResult matchAndRewriteSortOp(OpTy op, ValueRange xys, uint64_t nx,
}
FlatSymbolRefAttr func =
getMangledSortHelperFunc(rewriter, insertPoint, TypeRange(), funcName, nx,
ny, isCoo, operands, funcGenerator, nTrailingP);
getMangledSortHelperFunc(rewriter, insertPoint, TypeRange(), funcName,
xPerm, ny, operands, funcGenerator, nTrailingP);
rewriter.replaceOpWithNewOp<func::CallOp>(op, func, TypeRange(), operands);
return success();
}
@@ -1296,7 +1284,6 @@ LogicalResult matchAndRewriteSortOp(OpTy op, ValueRange xys, uint64_t nx,
//===---------------------------------------------------------------------===//
namespace {
/// Sparse rewriting rule for the push_back operator.
struct PushBackRewriter : OpRewritePattern<PushBackOp> {
public:
@@ -1410,20 +1397,6 @@ private:
bool enableBufferInitialization;
};
/// Sparse rewriting rule for the sort operator.
struct SortRewriter : public OpRewritePattern<SortOp> {
public:
using OpRewritePattern<SortOp>::OpRewritePattern;
LogicalResult matchAndRewrite(SortOp op,
PatternRewriter &rewriter) const override {
SmallVector<Value> xys(op.getXs());
xys.append(op.getYs().begin(), op.getYs().end());
return matchAndRewriteSortOp(op, xys, op.getXs().size(), /*ny=*/0,
/*isCoo=*/false, rewriter);
}
};
/// Sparse rewriting rule for the sort_coo operator.
struct SortCooRewriter : public OpRewritePattern<SortCooOp> {
public:
@@ -1434,16 +1407,13 @@ public:
SmallVector<Value> xys;
xys.push_back(op.getXy());
xys.append(op.getYs().begin(), op.getYs().end());
uint64_t nx = 1;
if (auto nxAttr = op.getNxAttr())
nx = nxAttr.getInt();
auto xPerm = op.getPermMap();
uint64_t ny = 0;
if (auto nyAttr = op.getNyAttr())
ny = nyAttr.getInt();
return matchAndRewriteSortOp(op, xys, nx, ny,
/*isCoo=*/true, rewriter);
return matchAndRewriteSortOp(op, xys, xPerm, ny, rewriter);
}
};
@@ -1457,5 +1427,5 @@ void mlir::populateSparseBufferRewriting(RewritePatternSet &patterns,
bool enableBufferInitialization) {
patterns.add<PushBackRewriter>(patterns.getContext(),
enableBufferInitialization);
patterns.add<SortRewriter, SortCooRewriter>(patterns.getContext());
patterns.add<SortCooRewriter>(patterns.getContext());
}

View File

@@ -890,8 +890,9 @@ public:
// If the innermost level is ordered, we need to sort the coordinates
// in the "added" array prior to applying the compression.
if (dstType.isOrderedLvl(dstType.getLvlRank() - 1))
rewriter.create<SortOp>(loc, count, ValueRange{added}, ValueRange{},
SparseTensorSortKind::HybridQuickSort);
rewriter.create<SortCooOp>(
loc, count, added, ValueRange{}, rewriter.getMultiDimIdentityMap(1),
rewriter.getIndexAttr(0), SparseTensorSortKind::HybridQuickSort);
// While performing the insertions, we also need to reset the elements
// of the values/filled-switch by only iterating over the set elements,
// to ensure that the runtime complexity remains proportional to the
@@ -1486,9 +1487,10 @@ struct SparseNewOpConverter : public OpConversionPattern<NewOp> {
scf::IfOp ifOp =
rewriter.create<scf::IfOp>(loc, notSorted, /*else*/ false);
rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
rewriter.create<SortCooOp>(
loc, nse, xs, ValueRange{ys}, rewriter.getIndexAttr(lvlRank),
rewriter.getIndexAttr(0), SparseTensorSortKind::HybridQuickSort);
auto xPerm = rewriter.getMultiDimIdentityMap(lvlRank);
rewriter.create<SortCooOp>(loc, nse, xs, ValueRange{ys}, xPerm,
rewriter.getIndexAttr(0),
SparseTensorSortKind::HybridQuickSort);
rewriter.setInsertionPointAfter(ifOp);
}

View File

@@ -207,7 +207,6 @@ struct SparseTensorCodegenPass
ConversionTarget target(*ctx);
// Most ops in the sparse dialect must go!
target.addIllegalDialect<SparseTensorDialect>();
target.addLegalOp<SortOp>();
target.addLegalOp<SortCooOp>();
target.addLegalOp<PushBackOp>();
// Storage specifier outlives sparse tensor pipeline.

View File

@@ -1206,29 +1206,23 @@ private:
// Retrieve the values-array.
Value y = genToValues(rewriter, loc, src);
const auto encSrc = srcTp.getEncoding();
// Sort the COO tensor so that its elements are ordered via increasing
// coordinates for the storage ordering of the dst tensor. Use SortCoo
// if the COO tensor has the same ordering as the dst tensor.
if (dimRank > 1 && srcTp.hasSameDimToLvl(dstTp)) {
Value xs = genToCoordinatesBuffer(rewriter, loc, src);
rewriter.create<SortCooOp>(
loc, nnz, xs, ValueRange{y}, rewriter.getIndexAttr(dimRank),
rewriter.getIndexAttr(0), SparseTensorSortKind::HybridQuickSort);
} else {
// Gather the coordinates-arrays in the dst tensor storage order.
SmallVector<Value> xs(dstLvlRank);
const Level srcLvlRank = srcTp.getLvlRank();
for (Level srcLvl = 0; srcLvl < srcLvlRank; srcLvl++) {
// FIXME: `toOrigDim` is deprecated
Dimension dim = toOrigDim(encSrc, srcLvl);
// FIXME: `toStoredDim` is deprecated
Level dstLvl = toStoredDim(encDst, dim);
xs[dstLvl] =
genToCoordinates(rewriter, loc, src, srcLvl, /*cooStart=*/0);
}
rewriter.create<SortOp>(loc, nnz, xs, ValueRange{y},
SparseTensorSortKind::HybridQuickSort);
// Builds the dstLvl -> srcLvl permutation maps.
SmallVector<AffineExpr> es(dstLvlRank);
const Level srcLvlRank = srcTp.getLvlRank();
for (Level srcLvl = 0; srcLvl < srcLvlRank; srcLvl++) {
// FIXME: `toOrigDim` is deprecated
Dimension dim = toOrigDim(encSrc, srcLvl);
// FIXME: `toStoredDim` is deprecated
Level dstLvl = toStoredDim(encDst, dim);
es[dstLvl] = rewriter.getAffineDimExpr(srcLvl);
}
auto xPerm = AffineMap::get(dstLvlRank, 0, es, rewriter.getContext());
assert(xPerm.isPermutation()); // must be a permutation.
Value xs = genToCoordinatesBuffer(rewriter, loc, src);
rewriter.create<SortCooOp>(loc, nnz, xs, ValueRange{y}, xPerm,
rewriter.getIndexAttr(0),
SparseTensorSortKind::HybridQuickSort);
}
// For each element in the COO tensor, insert the element to the dst tensor.

View File

@@ -75,123 +75,64 @@ func.func @sparse_push_back_inbound(%arg0: index, %arg1: memref<?xf64>, %arg2: f
// -----
// CHECK-LABEL: func.func private @_sparse_partition_1_i8_f32_index
// CHECK-LABEL: func.func private @_sparse_qsort_1_i8_f32_index
// CHECK-LABEL: func.func @sparse_sort_1d2v_quick
func.func @sparse_sort_1d2v_quick(%arg0: index, %arg1: memref<10xi8>, %arg2: memref<?xf32>, %arg3: memref<10xindex>)
-> (memref<10xi8>, memref<?xf32>, memref<10xindex>) {
sparse_tensor.sort quick_sort %arg0, %arg1 jointly %arg2, %arg3 : memref<10xi8> jointly memref<?xf32>, memref<10xindex>
return %arg1, %arg2, %arg3 : memref<10xi8>, memref<?xf32>, memref<10xindex>
}
// -----
// Only check the generated supporting function now. We have integration test
// to verify correctness of the generated code.
//
// CHECK-DAG: func.func private @_sparse_partition_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) -> index {
// CHECK-DAG: func.func private @_sparse_qsort_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) {
// CHECK-LABEL: func.func @sparse_sort_3d_quick
func.func @sparse_sort_3d_quick(%arg0: index, %arg1: memref<10xindex>, %arg2: memref<?xindex>, %arg3: memref<10xindex>) -> (memref<10xindex>, memref<?xindex>, memref<10xindex>) {
sparse_tensor.sort quick_sort %arg0, %arg1, %arg2, %arg3 : memref<10xindex>, memref<?xindex>, memref<10xindex>
return %arg1, %arg2, %arg3 : memref<10xindex>, memref<?xindex>, memref<10xindex>
}
// -----
// Only check the generated supporting function now. We have integration test
// to verify correctness of the generated code.
//
// CHECK-DAG: func.func private @_sparse_binary_search_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) -> index {
// CHECK-DAG: func.func private @_sparse_sort_stable_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) {
// CHECK-DAG: func.func private @_sparse_shift_down_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>, %arg5: index) {
// CHECK-DAG: func.func private @_sparse_heap_sort_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) {
// CHECK-DAG: func.func private @_sparse_partition_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) -> index {
// CHECK-DAG: func.func private @_sparse_hybrid_qsort_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>, %arg5: i64) {
// CHECK-LABEL: func.func @sparse_sort_3d_hybrid
func.func @sparse_sort_3d_hybrid(%arg0: index, %arg1: memref<10xindex>, %arg2: memref<?xindex>, %arg3: memref<10xindex>) -> (memref<10xindex>, memref<?xindex>, memref<10xindex>) {
sparse_tensor.sort hybrid_quick_sort %arg0, %arg1, %arg2, %arg3 : memref<10xindex>, memref<?xindex>, memref<10xindex>
return %arg1, %arg2, %arg3 : memref<10xindex>, memref<?xindex>, memref<10xindex>
}
// -----
#ID_MAP=affine_map<(d0, d1) -> (d0, d1)>
// Only check the generated supporting functions. We have integration test to
// verify correctness of the generated code.
//
// CHECK-DAG: func.func private @_sparse_binary_search_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) -> index {
// CHECK-DAG: func.func private @_sparse_sort_stable_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) {
// CHECK-LABEL: func.func @sparse_sort_3d_stable
func.func @sparse_sort_3d_stable(%arg0: index, %arg1: memref<10xindex>, %arg2: memref<?xindex>, %arg3: memref<10xindex>) -> (memref<10xindex>, memref<?xindex>, memref<10xindex>) {
sparse_tensor.sort insertion_sort_stable %arg0, %arg1, %arg2, %arg3 : memref<10xindex>, memref<?xindex>, memref<10xindex>
return %arg1, %arg2, %arg3 : memref<10xindex>, memref<?xindex>, memref<10xindex>
}
// -----
// Only check the generated supporting functions. We have integration test to
// verify correctness of the generated code.
//
// CHECK-DAG: func.func private @_sparse_shift_down_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>, %arg5: index) {
// CHECK-DAG: func.func private @_sparse_heap_sort_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) {
// CHECK-LABEL: func.func @sparse_sort_3d_heap
func.func @sparse_sort_3d_heap(%arg0: index, %arg1: memref<10xindex>, %arg2: memref<?xindex>, %arg3: memref<10xindex>) -> (memref<10xindex>, memref<?xindex>, memref<10xindex>) {
sparse_tensor.sort heap_sort %arg0, %arg1, %arg2, %arg3 : memref<10xindex>, memref<?xindex>, memref<10xindex>
return %arg1, %arg2, %arg3 : memref<10xindex>, memref<?xindex>, memref<10xindex>
}
// -----
// Only check the generated supporting functions. We have integration test to
// verify correctness of the generated code.
//
// CHECK-DAG: func.func private @_sparse_partition_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) -> index {
// CHECK-DAG: func.func private @_sparse_qsort_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) {
// CHECK-DAG: func.func private @_sparse_partition_0_1_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) -> index {
// CHECK-DAG: func.func private @_sparse_qsort_0_1_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) {
// CHECK-LABEL: func.func @sparse_sort_coo_quick
func.func @sparse_sort_coo_quick(%arg0: index, %arg1: memref<100xindex>, %arg2: memref<?xf32>, %arg3: memref<10xi32>) -> (memref<100xindex>, memref<?xf32>, memref<10xi32>) {
sparse_tensor.sort_coo quick_sort %arg0, %arg1 jointly %arg2, %arg3 {nx = 2 : index, ny = 1: index} : memref<100xindex> jointly memref<?xf32>, memref<10xi32>
sparse_tensor.sort_coo quick_sort %arg0, %arg1 jointly %arg2, %arg3 {perm_map = #ID_MAP, ny = 1: index} : memref<100xindex> jointly memref<?xf32>, memref<10xi32>
return %arg1, %arg2, %arg3 : memref<100xindex>, memref<?xf32>, memref<10xi32>
}
// -----
#ID_MAP=affine_map<(d0, d1) -> (d0, d1)>
// Only check the generated supporting functions. We have integration test to
// verify correctness of the generated code.
//
// CHECK-DAG: func.func private @_sparse_binary_search_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) -> index {
// CHECK-DAG: func.func private @_sparse_sort_stable_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) {
// CHECK-DAG: func.func private @_sparse_shift_down_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>, %arg5: index) {
// CHECK-DAG: func.func private @_sparse_heap_sort_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) {
// CHECK-DAG: func.func private @_sparse_partition_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) -> index {
// CHECK-DAG: func.func private @_sparse_hybrid_qsort_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>, %arg5: i64) {
// CHECK-DAG: func.func private @_sparse_binary_search_0_1_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) -> index {
// CHECK-DAG: func.func private @_sparse_sort_stable_0_1_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) {
// CHECK-DAG: func.func private @_sparse_shift_down_0_1_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>, %arg5: index) {
// CHECK-DAG: func.func private @_sparse_heap_sort_0_1_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) {
// CHECK-DAG: func.func private @_sparse_partition_0_1_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) -> index {
// CHECK-DAG: func.func private @_sparse_hybrid_qsort_0_1_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>, %arg5: i64) {
// CHECK-LABEL: func.func @sparse_sort_coo_hybrid
func.func @sparse_sort_coo_hybrid(%arg0: index, %arg1: memref<100xindex>, %arg2: memref<?xf32>, %arg3: memref<10xi32>) -> (memref<100xindex>, memref<?xf32>, memref<10xi32>) {
sparse_tensor.sort_coo hybrid_quick_sort %arg0, %arg1 jointly %arg2, %arg3 {nx = 2 : index, ny = 1: index} : memref<100xindex> jointly memref<?xf32>, memref<10xi32>
sparse_tensor.sort_coo hybrid_quick_sort %arg0, %arg1 jointly %arg2, %arg3 {perm_map = #ID_MAP, ny = 1: index} : memref<100xindex> jointly memref<?xf32>, memref<10xi32>
return %arg1, %arg2, %arg3 : memref<100xindex>, memref<?xf32>, memref<10xi32>
}
// -----
#ID_MAP=affine_map<(d0, d1) -> (d0, d1)>
// Only check the generated supporting functions. We have integration test to
// verify correctness of the generated code.
//
// CHECK-DAG: func.func private @_sparse_binary_search_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) -> index {
// CHECK-DAG: func.func private @_sparse_sort_stable_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) {
// CHECK-DAG: func.func private @_sparse_binary_search_0_1_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) -> index {
// CHECK-DAG: func.func private @_sparse_sort_stable_0_1_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) {
// CHECK-LABEL: func.func @sparse_sort_coo_stable
func.func @sparse_sort_coo_stable(%arg0: index, %arg1: memref<100xindex>, %arg2: memref<?xf32>, %arg3: memref<10xi32>) -> (memref<100xindex>, memref<?xf32>, memref<10xi32>) {
sparse_tensor.sort_coo insertion_sort_stable %arg0, %arg1 jointly %arg2, %arg3 {nx = 2 : index, ny = 1: index} : memref<100xindex> jointly memref<?xf32>, memref<10xi32>
sparse_tensor.sort_coo insertion_sort_stable %arg0, %arg1 jointly %arg2, %arg3 {perm_map = #ID_MAP, ny = 1: index} : memref<100xindex> jointly memref<?xf32>, memref<10xi32>
return %arg1, %arg2, %arg3 : memref<100xindex>, memref<?xf32>, memref<10xi32>
}
// -----
#ID_MAP=affine_map<(d0, d1) -> (d0, d1)>
// Only check the generated supporting functions. We have integration test to
// verify correctness of the generated code.
//
// CHECK-DAG: func.func private @_sparse_shift_down_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>, %arg5: index) {
// CHECK-DAG: func.func private @_sparse_heap_sort_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) {
// CHECK-DAG: func.func private @_sparse_shift_down_0_1_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>, %arg5: index) {
// CHECK-DAG: func.func private @_sparse_heap_sort_0_1_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) {
// CHECK-LABEL: func.func @sparse_sort_coo_heap
func.func @sparse_sort_coo_heap(%arg0: index, %arg1: memref<100xindex>, %arg2: memref<?xf32>, %arg3: memref<10xi32>) -> (memref<100xindex>, memref<?xf32>, memref<10xi32>) {
sparse_tensor.sort_coo heap_sort %arg0, %arg1 jointly %arg2, %arg3 {nx = 2 : index, ny = 1: index} : memref<100xindex> jointly memref<?xf32>, memref<10xi32>
sparse_tensor.sort_coo heap_sort %arg0, %arg1 jointly %arg2, %arg3 {perm_map = #ID_MAP, ny = 1: index} : memref<100xindex> jointly memref<?xf32>, memref<10xi32>
return %arg1, %arg2, %arg3 : memref<100xindex>, memref<?xf32>, memref<10xi32>
}

View File

@@ -436,7 +436,7 @@ func.func @sparse_expansion3(%arg0: index, %arg1: index) -> memref<?xindex> {
// CHECK-DAG: %[[A9:.*]] = arith.constant 0.000000e+00 : f64
// CHECK-DAG: %[[A10:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[A11:.*]] = arith.constant 0 : index
// CHECK: sparse_tensor.sort hybrid_quick_sort %[[A7]], %[[A6]] : memref<?xindex>
// CHECK: sparse_tensor.sort_coo hybrid_quick_sort %[[A7]], %[[A6]]
// CHECK: %[[A12:.*]]:4 = scf.for %[[A13:.*]] = %[[A11]] to %[[A7]] step %[[A10]] iter_args(%[[A14:.*]] = %[[A0]], %[[A15:.*]] = %[[A1]], %[[A16:.*]] = %[[A2]], %[[A17:.*]] = %[[A3]])
// CHECK: %[[A18:.*]] = memref.load %[[A6]]{{\[}}%[[A13]]] : memref<?xindex>
// CHECK: %[[A19:.*]] = memref.load %[[A4]]{{\[}}%[[A18]]] : memref<?xf64>
@@ -484,7 +484,7 @@ func.func @sparse_compression_1d(%tensor: tensor<100xf64, #SV>,
// CHECK: %[[A11:.*]] = arith.constant 0.000000e+00 : f64
// CHECK: %[[A12:.*]] = arith.constant 1 : index
// CHECK: %[[A13:.*]] = arith.constant 0 : index
// CHECK: sparse_tensor.sort hybrid_quick_sort %[[A7]], %[[A6]] : memref<?xindex>
// CHECK: sparse_tensor.sort_coo hybrid_quick_sort %[[A7]], %[[A6]]
// CHECK: %[[A14:.*]]:4 = scf.for %[[A15:.*]] = %[[A13]] to %[[A7]] step %[[A12]] iter_args(%[[A16:.*]] = %[[A0]], %[[A17:.*]] = %[[A1]], %[[A18:.*]] = %[[A2]], %[[A19:.*]] = %[[A3]]) -> (memref<?xi32>, memref<?xi64>, memref<?xf64>, !sparse_tensor.storage_specifier
// CHECK: %[[A20:.*]] = memref.load %[[A6]]{{\[}}%[[A15]]] : memref<?xindex>
// CHECK: %[[A21:.*]] = memref.load %[[A4]]{{\[}}%[[A20]]] : memref<?xf64>
@@ -712,7 +712,7 @@ func.func @sparse_convert_element_type(%arg0: tensor<32xf32, #SparseVector>) ->
// CHECK: %[[A33:.*]] = call @getSparseTensorReaderReadToBuffers0F32(%[[A5]], %[[A32]], %[[A14]], %[[A15]])
// CHECK: %[[A34:.*]] = arith.cmpi eq, %[[A33]], %[[A1]] : i1
// CHECK: scf.if %[[A34]] {
// CHECK: sparse_tensor.sort_coo hybrid_quick_sort %[[A10]], %[[A14]] jointly %[[A15]] {nx = 2 : index, ny = 0 : index} : memref<?xindex> jointly memref<?xf32>
// CHECK: sparse_tensor.sort_coo hybrid_quick_sort %[[A10]], %[[A14]] jointly %[[A15]] {ny = 0 : index, perm_map = #{{.*}}} : memref<?xindex> jointly memref<?xf32>
// CHECK: }
// CHECK: memref.store %[[A10]], %[[A27]]{{\[}}%[[A2]]] : memref<?xindex>
// CHECK: %[[A36:.*]] = sparse_tensor.storage_specifier.set %[[A30]] crd_mem_sz at 0 with %[[A11]]

View File

@@ -178,7 +178,7 @@ func.func @sparse_convert_singleton(%arg0: tensor<?xf32, #SparseSingleton64>) ->
// CHECK-RWT: %[[VAL_16:.*]] = sparse_tensor.load %[[VAL_17:.*]] hasInserts : tensor<?x?x?xf32, #{{.*}}>>
// CHECK-RWT: %[[VAL_18:.*]] = sparse_tensor.values %[[VAL_16]] : tensor<?x?x?xf32, #{{.*}}>> to memref<?xf32>
// CHECK-RWT: %[[VAL_19:.*]] = sparse_tensor.coordinates_buffer %[[VAL_16]] : tensor<?x?x?xf32, #{{.*}}>> to memref<?xindex>
// CHECK-RWT: sparse_tensor.sort_coo hybrid_quick_sort %[[VAL_7]], %[[VAL_19]] jointly %[[VAL_18]] {nx = 3 : index, ny = 0 : index}
// CHECK-RWT: sparse_tensor.sort_coo hybrid_quick_sort %[[VAL_7]], %[[VAL_19]] jointly %[[VAL_18]] {ny = 0 : index, perm_map = #map}
// CHECK-RWT: %[[VAL_20:.*]] = bufferization.alloc_tensor(%[[VAL_4]], %[[VAL_5]], %[[VAL_6]]) size_hint=%[[VAL_7]]
// CHECK-RWT: %[[VAL_21:.*]] = sparse_tensor.foreach in %[[VAL_16]] init(%[[VAL_20]])
// CHECK-RWT: ^bb0(%[[VAL_22:.*]]: index, %[[VAL_23:.*]]: index, %[[VAL_24:.*]]: index, %[[VAL_25:.*]]: f32, %[[VAL_26:.*]]: tensor<?x?x?xf32, #{{.*}}>>):

View File

@@ -790,60 +790,51 @@ func.func @sparse_tensor_foreach(%arg0: tensor<2x4xf64, #DCSR>, %arg1: f32) -> (
return
}
// -----
// TODO: a test case with empty xs doesn't work due to some parser issues.
func.func @sparse_sort_x_type( %arg0: index, %arg1: memref<?xf32>) {
// expected-error@+1 {{operand #1 must be 1D memref of integer or index values}}
sparse_tensor.sort hybrid_quick_sort %arg0, %arg1: memref<?xf32>
}
// -----
func.func @sparse_sort_dim_too_small(%arg0: memref<10xindex>) {
%i20 = arith.constant 20 : index
// expected-error@+1 {{xs and ys need to have a dimension >= n: 10 < 20}}
sparse_tensor.sort insertion_sort_stable %i20, %arg0 : memref<10xindex>
return
}
// -----
func.func @sparse_sort_mismatch_x_type(%arg0: index, %arg1: memref<10xindex>, %arg2: memref<10xi8>) {
// expected-error@+1 {{mismatch xs element types}}
sparse_tensor.sort hybrid_quick_sort %arg0, %arg1, %arg2 : memref<10xindex>, memref<10xi8>
return
}
// -----
#MAP = affine_map<(i,j) -> (i,j)>
func.func @sparse_sort_coo_x_type( %arg0: index, %arg1: memref<?xf32>) {
// expected-error@+1 {{operand #1 must be 1D memref of integer or index values}}
sparse_tensor.sort_coo insertion_sort_stable %arg0, %arg1: memref<?xf32>
sparse_tensor.sort_coo insertion_sort_stable %arg0, %arg1 {perm_map = #MAP} : memref<?xf32>
return
}
// -----
#MAP = affine_map<(i,j) -> (i,j)>
func.func @sparse_sort_coo_x_too_small(%arg0: memref<50xindex>) {
%i20 = arith.constant 20 : index
// expected-error@+1 {{Expected dimension(xy) >= n * (nx + ny) got 50 < 60}}
sparse_tensor.sort_coo hybrid_quick_sort %i20, %arg0 {nx = 2 : index, ny = 1 : index} : memref<50xindex>
// expected-error@+1 {{Expected dimension(xy) >= n * (rank(perm_map) + ny) got 50 < 60}}
sparse_tensor.sort_coo hybrid_quick_sort %i20, %arg0 {perm_map = #MAP, ny = 1 : index} : memref<50xindex>
return
}
// -----
#MAP = affine_map<(i,j) -> (i,j)>
func.func @sparse_sort_coo_y_too_small(%arg0: memref<60xindex>, %arg1: memref<10xf32>) {
%i20 = arith.constant 20 : index
// expected-error@+1 {{Expected dimension(y) >= n got 10 < 20}}
sparse_tensor.sort_coo insertion_sort_stable %i20, %arg0 jointly %arg1 {nx = 2 : index, ny = 1 : index} : memref<60xindex> jointly memref<10xf32>
sparse_tensor.sort_coo insertion_sort_stable %i20, %arg0 jointly %arg1 {perm_map = #MAP, ny = 1 : index} : memref<60xindex> jointly memref<10xf32>
return
}
// -----
#NON_PERM_MAP = affine_map<(i,j) -> (i,i)>
func.func @sparse_sort_coo_no_perm(%arg0: index, %arg1: memref<?xindex>) -> (memref<?xindex>) {
// expected-error@+1 {{Expected a permutation map, got (d0, d1) -> (d0, d0)}}
sparse_tensor.sort_coo hybrid_quick_sort %arg0, %arg1 {perm_map = #NON_PERM_MAP, ny = 1 : index}: memref<?xindex>
return %arg1 : memref<?xindex>
}
// -----
#CSR = #sparse_tensor.encoding<{map = (d0, d1) -> (d0 : dense, d1 : compressed)}>
func.func @sparse_alloc_escapes(%arg0: index) -> tensor<10x?xf64, #CSR> {

View File

@@ -612,79 +612,29 @@ func.func @sparse_tensor_foreach(%arg0: tensor<2x4xf64, #DCSR>, %arg1: f32) -> (
// -----
// CHECK-LABEL: func @sparse_sort_1d0v(
// CHECK-SAME: %[[A:.*]]: index,
// CHECK-SAME: %[[B:.*]]: memref<?xindex>)
// CHECK: sparse_tensor.sort hybrid_quick_sort %[[A]], %[[B]] : memref<?xindex>
// CHECK: return %[[B]]
func.func @sparse_sort_1d0v(%arg0: index, %arg1: memref<?xindex>) -> (memref<?xindex>) {
sparse_tensor.sort hybrid_quick_sort %arg0, %arg1 : memref<?xindex>
return %arg1 : memref<?xindex>
}
// -----
// CHECK-LABEL: func @sparse_sort_1d2v(
// CHECK-SAME: %[[A:.*]]: index,
// CHECK-SAME: %[[B:.*]]: memref<20xindex>,
// CHECK-SAME: %[[C:.*]]: memref<10xindex>,
// CHECK-SAME: %[[D:.*]]: memref<?xf32>)
// CHECK: sparse_tensor.sort hybrid_quick_sort %[[A]], %[[B]] jointly %[[C]], %[[D]] : memref<20xindex> jointly memref<10xindex>, memref<?xf32>
// CHECK: return %[[B]], %[[C]], %[[D]]
func.func @sparse_sort_1d2v(%arg0: index, %arg1: memref<20xindex>, %arg2: memref<10xindex>, %arg3: memref<?xf32>) -> (memref<20xindex>, memref<10xindex>, memref<?xf32>) {
sparse_tensor.sort hybrid_quick_sort %arg0, %arg1 jointly %arg2, %arg3 : memref<20xindex> jointly memref<10xindex>, memref<?xf32>
return %arg1, %arg2, %arg3 : memref<20xindex>, memref<10xindex>, memref<?xf32>
}
// -----
// CHECK-LABEL: func @sparse_sort_2d1v(
// CHECK-SAME: %[[A:.*]]: index,
// CHECK-SAME: %[[B:.*]]: memref<10xi8>,
// CHECK-SAME: %[[C:.*]]: memref<20xi8>,
// CHECK-SAME: %[[D:.*]]: memref<10xf64>)
// CHECK: sparse_tensor.sort hybrid_quick_sort %[[A]], %[[B]], %[[C]] jointly %[[D]] : memref<10xi8>, memref<20xi8> jointly memref<10xf64>
// CHECK: return %[[B]], %[[C]], %[[D]]
func.func @sparse_sort_2d1v(%arg0: index, %arg1: memref<10xi8>, %arg2: memref<20xi8>, %arg3: memref<10xf64>) -> (memref<10xi8>, memref<20xi8>, memref<10xf64>) {
sparse_tensor.sort hybrid_quick_sort %arg0, %arg1, %arg2 jointly %arg3 : memref<10xi8>, memref<20xi8> jointly memref<10xf64>
return %arg1, %arg2, %arg3 : memref<10xi8>, memref<20xi8>, memref<10xf64>
}
// -----
// CHECK-LABEL: func @sparse_sort_stable(
// CHECK-SAME: %[[A:.*]]: index,
// CHECK-SAME: %[[B:.*]]: memref<10xi8>,
// CHECK-SAME: %[[C:.*]]: memref<20xi8>,
// CHECK-SAME: %[[D:.*]]: memref<10xf64>)
// CHECK: sparse_tensor.sort insertion_sort_stable %[[A]], %[[B]], %[[C]] jointly %[[D]] : memref<10xi8>, memref<20xi8> jointly memref<10xf64>
// CHECK: return %[[B]], %[[C]], %[[D]]
func.func @sparse_sort_stable(%arg0: index, %arg1: memref<10xi8>, %arg2: memref<20xi8>, %arg3: memref<10xf64>) -> (memref<10xi8>, memref<20xi8>, memref<10xf64>) {
sparse_tensor.sort insertion_sort_stable %arg0, %arg1, %arg2 jointly %arg3 : memref<10xi8>, memref<20xi8> jointly memref<10xf64>
return %arg1, %arg2, %arg3 : memref<10xi8>, memref<20xi8>, memref<10xf64>
}
// -----
#ID_MAP = affine_map<(i,j) -> (i,j)>
// CHECK-LABEL: func @sparse_sort_coo(
// CHECK-SAME: %[[A:.*]]: index,
// CHECK-SAME: %[[B:.*]]: memref<?xindex>)
// CHECK: sparse_tensor.sort_coo hybrid_quick_sort %[[A]], %[[B]] {nx = 2 : index, ny = 1 : index} : memref<?xindex>
// CHECK: sparse_tensor.sort_coo hybrid_quick_sort %[[A]], %[[B]] {ny = 1 : index, perm_map = #{{.*}}} : memref<?xindex>
// CHECK: return %[[B]]
func.func @sparse_sort_coo(%arg0: index, %arg1: memref<?xindex>) -> (memref<?xindex>) {
sparse_tensor.sort_coo hybrid_quick_sort %arg0, %arg1 {nx = 2 : index, ny = 1 : index}: memref<?xindex>
sparse_tensor.sort_coo hybrid_quick_sort %arg0, %arg1 {perm_map = #ID_MAP, ny = 1 : index}: memref<?xindex>
return %arg1 : memref<?xindex>
}
// -----
#ID_MAP = affine_map<(i,j) -> (i,j)>
// CHECK-LABEL: func @sparse_sort_coo_stable(
// CHECK-SAME: %[[A:.*]]: index,
// CHECK-SAME: %[[B:.*]]: memref<?xi64>,
// CHECK-SAME: %[[C:.*]]: memref<?xf32>)
// CHECK: sparse_tensor.sort_coo insertion_sort_stable %[[A]], %[[B]] jointly %[[C]] {nx = 2 : index, ny = 1 : index}
// CHECK: sparse_tensor.sort_coo insertion_sort_stable %[[A]], %[[B]] jointly %[[C]] {ny = 1 : index, perm_map = #{{.*}}}
// CHECK: return %[[B]], %[[C]]
func.func @sparse_sort_coo_stable(%arg0: index, %arg1: memref<?xi64>, %arg2: memref<?xf32>) -> (memref<?xi64>, memref<?xf32>) {
sparse_tensor.sort_coo insertion_sort_stable %arg0, %arg1 jointly %arg2 {nx = 2 : index, ny = 1 : index}: memref<?xi64> jointly memref<?xf32>
sparse_tensor.sort_coo insertion_sort_stable %arg0, %arg1 jointly %arg2 {perm_map = #ID_MAP, ny = 1 : index}: memref<?xi64> jointly memref<?xf32>
return %arg1, %arg2 : memref<?xi64>, memref<?xf32>
}

View File

@@ -116,7 +116,7 @@
// CHECK: } {"Emitted from" = "linalg.generic"}
// CHECK: scf.yield %[[VAL_64:.*]] : index
// CHECK: } {"Emitted from" = "linalg.generic"}
// CHECK: sparse_tensor.sort hybrid_quick_sort %[[VAL_65:.*]], %[[VAL_33]] : memref<?xindex>
// CHECK: sparse_tensor.sort_coo hybrid_quick_sort %[[VAL_65:.*]], %[[VAL_33]]
// CHECK: %[[VAL_66:.*]]:4 = scf.for %[[VAL_67:.*]] = %[[VAL_10]] to %[[VAL_65]] step %[[VAL_11]] iter_args(%[[VAL_68:.*]] = %[[VAL_36]], %[[VAL_69:.*]] = %[[VAL_37]], %[[VAL_70:.*]] = %[[VAL_38]], %[[VAL_71:.*]] = %[[VAL_39]]) -> (memref<?xindex>, memref<?xindex>, memref<?xf64>, !sparse_tensor.storage_specifier
// CHECK: %[[VAL_72:.*]] = memref.load %[[VAL_32]]{{\[}}%[[VAL_67]]] : memref<4xindex>
// CHECK: %[[VAL_73:.*]] = memref.load %[[VAL_30]]{{\[}}%[[VAL_72]]] : memref<4xf64>

View File

@@ -1,187 +0,0 @@
//--------------------------------------------------------------------------------------------------
// WHEN CREATING A NEW TEST, PLEASE JUST COPY & PASTE WITHOUT EDITS.
//
// Set-up that's shared across all tests in this directory. In principle, this
// config could be moved to lit.local.cfg. However, there are downstream users that
// do not use these LIT config files. Hence why this is kept inline.
//
// DEFINE: %{sparse_compiler_opts} = enable-runtime-library=true
// DEFINE: %{sparse_compiler_opts_sve} = enable-arm-sve=true %{sparse_compiler_opts}
// DEFINE: %{compile} = mlir-opt %s --sparse-compiler="%{sparse_compiler_opts}"
// DEFINE: %{compile_sve} = mlir-opt %s --sparse-compiler="%{sparse_compiler_opts_sve}"
// DEFINE: %{run_libs} = -shared-libs=%mlir_c_runner_utils,%mlir_runner_utils
// DEFINE: %{run_opts} = -e entry -entry-point-result=void
// DEFINE: %{run} = mlir-cpu-runner %{run_opts} %{run_libs}
// DEFINE: %{run_sve} = %mcr_aarch64_cmd --march=aarch64 --mattr="+sve" %{run_opts} %{run_libs}
//
// DEFINE: %{env} =
//--------------------------------------------------------------------------------------------------
// REDEFINE: %{sparse_compiler_opts} = enable-runtime-library=false
// RUN: %{compile} | %{run} | FileCheck %s
//
// Do the same run, but now with vectorization.
// REDEFINE: %{sparse_compiler_opts} = enable-runtime-library=false vl=2 reassociate-fp-reductions=true enable-index-optimizations=true
// RUN: %{compile} | %{run} | FileCheck %s
//
// Do the same run, but now with VLA vectorization.
// RUN: %if mlir_arm_sve_tests %{ %{compile_sve} | %{run_sve} | FileCheck %s %}
module {
func.func private @printMemref1dI32(%ptr : memref<?xi32>) attributes { llvm.emit_c_interface }
// Stores 5 values to the memref buffer.
func.func @storeValuesTo(%b: memref<?xi32>, %v0: i32, %v1: i32, %v2: i32,
%v3: i32, %v4: i32) -> () {
%i0 = arith.constant 0 : index
%i1 = arith.constant 1 : index
%i2 = arith.constant 2 : index
%i3 = arith.constant 3 : index
%i4 = arith.constant 4 : index
memref.store %v0, %b[%i0] : memref<?xi32>
memref.store %v1, %b[%i1] : memref<?xi32>
memref.store %v2, %b[%i2] : memref<?xi32>
memref.store %v3, %b[%i3] : memref<?xi32>
memref.store %v4, %b[%i4] : memref<?xi32>
return
}
// The main driver.
func.func @entry() {
%c0 = arith.constant 0 : i32
%c1 = arith.constant 1 : i32
%c2 = arith.constant 2 : i32
%c3 = arith.constant 3 : i32
%c4 = arith.constant 4 : i32
%c5 = arith.constant 5 : i32
%c6 = arith.constant 6 : i32
%c7 = arith.constant 7 : i32
%c8 = arith.constant 8 : i32
%c9 = arith.constant 9 : i32
%c10 = arith.constant 10 : i32
%c100 = arith.constant 100 : i32
%i0 = arith.constant 0 : index
%i4 = arith.constant 4 : index
%i5 = arith.constant 5 : index
// Prepare a buffer.
%x0s = memref.alloc() : memref<5xi32>
%x0 = memref.cast %x0s : memref<5xi32> to memref<?xi32>
call @storeValuesTo(%x0, %c10, %c2, %c0, %c5, %c1)
: (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
// Sort 0 elements.
// Quick sort.
// CHECK: [10, 2, 0, 5, 1]
sparse_tensor.sort quick_sort %i0, %x0 : memref<?xi32>
call @printMemref1dI32(%x0) : (memref<?xi32>) -> ()
// Stable sort.
// CHECK: [10, 2, 0, 5, 1]
sparse_tensor.sort insertion_sort_stable %i0, %x0 : memref<?xi32>
call @printMemref1dI32(%x0) : (memref<?xi32>) -> ()
// Heap sort.
// CHECK: [10, 2, 0, 5, 1]
sparse_tensor.sort heap_sort %i0, %x0 : memref<?xi32>
call @printMemref1dI32(%x0) : (memref<?xi32>) -> ()
// Hybrid sort.
// CHECK: [10, 2, 0, 5, 1]
sparse_tensor.sort hybrid_quick_sort %i0, %x0 : memref<?xi32>
call @printMemref1dI32(%x0) : (memref<?xi32>) -> ()
// Sort the first 4 elements, with the last valid value untouched.
// Quick sort.
// CHECK: [0, 2, 5, 10, 1]
sparse_tensor.sort quick_sort %i4, %x0 : memref<?xi32>
call @printMemref1dI32(%x0) : (memref<?xi32>) -> ()
// Stable sort.
// CHECK: [0, 2, 5, 10, 1]
call @storeValuesTo(%x0, %c10, %c2, %c0, %c5, %c1)
: (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
sparse_tensor.sort insertion_sort_stable %i4, %x0 : memref<?xi32>
call @printMemref1dI32(%x0) : (memref<?xi32>) -> ()
// Heap sort.
// CHECK: [0, 2, 5, 10, 1]
call @storeValuesTo(%x0, %c10, %c2, %c0, %c5, %c1)
: (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
sparse_tensor.sort heap_sort %i4, %x0 : memref<?xi32>
call @printMemref1dI32(%x0) : (memref<?xi32>) -> ()
// Hybrid sort.
// CHECK: [0, 2, 5, 10, 1]
sparse_tensor.sort hybrid_quick_sort %i4, %x0 : memref<?xi32>
call @printMemref1dI32(%x0) : (memref<?xi32>) -> ()
// Prepare more buffers of different dimensions.
%x1s = memref.alloc() : memref<10xi32>
%x1 = memref.cast %x1s : memref<10xi32> to memref<?xi32>
%x2s = memref.alloc() : memref<6xi32>
%x2 = memref.cast %x2s : memref<6xi32> to memref<?xi32>
%y0s = memref.alloc() : memref<7xi32>
%y0 = memref.cast %y0s : memref<7xi32> to memref<?xi32>
// Sort "parallel arrays".
// CHECK: [1, 1, 2, 5, 10]
// CHECK: [3, 3, 1, 10, 1
// CHECK: [9, 9, 4, 7, 2
// CHECK: [7, 8, 10, 9, 6
call @storeValuesTo(%x0, %c10, %c2, %c1, %c5, %c1)
: (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
call @storeValuesTo(%x1, %c1, %c1, %c3, %c10, %c3)
: (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
call @storeValuesTo(%x2, %c2, %c4, %c9, %c7, %c9)
: (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
call @storeValuesTo(%y0, %c6, %c10, %c8, %c9, %c7)
: (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
sparse_tensor.sort quick_sort %i5, %x0, %x1, %x2 jointly %y0
: memref<?xi32>, memref<?xi32>, memref<?xi32> jointly memref<?xi32>
call @printMemref1dI32(%x0) : (memref<?xi32>) -> ()
call @printMemref1dI32(%x1) : (memref<?xi32>) -> ()
call @printMemref1dI32(%x2) : (memref<?xi32>) -> ()
call @printMemref1dI32(%y0) : (memref<?xi32>) -> ()
// Stable sort.
// CHECK: [1, 1, 2, 5, 10]
// CHECK: [3, 3, 1, 10, 1
// CHECK: [9, 9, 4, 7, 2
// CHECK: [8, 7, 10, 9, 6
call @storeValuesTo(%x0, %c10, %c2, %c1, %c5, %c1)
: (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
call @storeValuesTo(%x1, %c1, %c1, %c3, %c10, %c3)
: (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
call @storeValuesTo(%x2, %c2, %c4, %c9, %c7, %c9)
: (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
call @storeValuesTo(%y0, %c6, %c10, %c8, %c9, %c7)
: (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
sparse_tensor.sort insertion_sort_stable %i5, %x0, %x1, %x2 jointly %y0
: memref<?xi32>, memref<?xi32>, memref<?xi32> jointly memref<?xi32>
call @printMemref1dI32(%x0) : (memref<?xi32>) -> ()
call @printMemref1dI32(%x1) : (memref<?xi32>) -> ()
call @printMemref1dI32(%x2) : (memref<?xi32>) -> ()
call @printMemref1dI32(%y0) : (memref<?xi32>) -> ()
// Heap sort.
// CHECK: [1, 1, 2, 5, 10]
// CHECK: [3, 3, 1, 10, 1
// CHECK: [9, 9, 4, 7, 2
// CHECK: [7, 8, 10, 9, 6
call @storeValuesTo(%x0, %c10, %c2, %c1, %c5, %c1)
: (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
call @storeValuesTo(%x1, %c1, %c1, %c3, %c10, %c3)
: (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
call @storeValuesTo(%x2, %c2, %c4, %c9, %c7, %c9)
: (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
call @storeValuesTo(%y0, %c6, %c10, %c8, %c9, %c7)
: (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
sparse_tensor.sort heap_sort %i5, %x0, %x1, %x2 jointly %y0
: memref<?xi32>, memref<?xi32>, memref<?xi32> jointly memref<?xi32>
call @printMemref1dI32(%x0) : (memref<?xi32>) -> ()
call @printMemref1dI32(%x1) : (memref<?xi32>) -> ()
call @printMemref1dI32(%x2) : (memref<?xi32>) -> ()
call @printMemref1dI32(%y0) : (memref<?xi32>) -> ()
// Release the buffers.
memref.dealloc %x0 : memref<?xi32>
memref.dealloc %x1 : memref<?xi32>
memref.dealloc %x2 : memref<?xi32>
memref.dealloc %y0 : memref<?xi32>
return
}
}

View File

@@ -28,6 +28,8 @@
// Do the same run, but now with VLA vectorization.
// RUN: %if mlir_arm_sve_tests %{ %{compile_sve} | %{run_sve} | FileCheck %s %}
#ID_MAP = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
module {
// Stores 5 values to the memref buffer.
func.func @storeValuesTo(%b: memref<?xi32>, %v0: i32, %v1: i32, %v2: i32,
@@ -94,11 +96,11 @@ module {
%y1 = memref.cast %y1s : memref<7xi32> to memref<?xi32>
// Sort "parallel arrays".
// CHECK: ( 1, 1, 3, 3, 10 )
// CHECK: ( 2, 10, 1, 1, 5 )
// CHECK: ( 4, 2, 9, 9, 7 )
// CHECK: ( 10, 6, 7, 8, 9 )
// CHECK: ( 7, 5, 7, 4, 9 )
// CHECK: ( 1, 1, 2, 5, 10 )
// CHECK: ( 9, 9, 4, 7, 2 )
// CHECK: ( 3, 3, 1, 10, 1 )
// CHECK: ( 7, 8, 10, 9, 6 )
// CHECK: ( 7, 4, 7, 9, 5 )
call @storeValuesToStrided(%x0, %c1, %c1, %c3, %c10, %c3)
: (memref<?xi32, strided<[4], offset: ?>>, i32, i32, i32, i32, i32) -> ()
call @storeValuesToStrided(%x1, %c10, %c2, %c1, %c5, %c1)
@@ -109,24 +111,25 @@ module {
: (memref<?xi32, strided<[4], offset: ?>>, i32, i32, i32, i32, i32) -> ()
call @storeValuesTo(%y1, %c5, %c7, %c4, %c9, %c7)
: (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
sparse_tensor.sort_coo quick_sort %i5, %xy jointly %y1 {nx = 3 : index, ny = 1 : index}
sparse_tensor.sort_coo quick_sort %i5, %xy jointly %y1 {perm_map = #ID_MAP, ny = 1 : index}
: memref<?xi32> jointly memref<?xi32>
%x0v = vector.transfer_read %x0[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
vector.print %x0v : vector<5xi32>
// Dumps memory in the same order as the perm_map such that the output is ordered.
%x1v = vector.transfer_read %x1[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
vector.print %x1v : vector<5xi32>
%x2v = vector.transfer_read %x2[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
vector.print %x2v : vector<5xi32>
%x0v = vector.transfer_read %x0[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
vector.print %x0v : vector<5xi32>
%y0v = vector.transfer_read %y0[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
vector.print %y0v : vector<5xi32>
%y1v = vector.transfer_read %y1[%i0], %c100: memref<?xi32>, vector<5xi32>
vector.print %y1v : vector<5xi32>
// Stable sort.
// CHECK: ( 1, 1, 3, 3, 10 )
// CHECK: ( 2, 10, 1, 1, 5 )
// CHECK: ( 4, 2, 9, 9, 7 )
// CHECK: ( 10, 6, 8, 7, 9 )
// CHECK: ( 7, 5, 4, 7, 9 )
// CHECK: ( 1, 1, 2, 5, 10 )
// CHECK: ( 9, 9, 4, 7, 2 )
// CHECK: ( 3, 3, 1, 10, 1 )
// CHECK: ( 8, 7, 10, 9, 6 )
// CHECK: ( 4, 7, 7, 9, 5 )
call @storeValuesToStrided(%x0, %c1, %c1, %c3, %c10, %c3)
: (memref<?xi32, strided<[4], offset: ?>>, i32, i32, i32, i32, i32) -> ()
call @storeValuesToStrided(%x1, %c10, %c2, %c1, %c5, %c1)
@@ -137,24 +140,24 @@ module {
: (memref<?xi32, strided<[4], offset: ?>>, i32, i32, i32, i32, i32) -> ()
call @storeValuesTo(%y1, %c5, %c7, %c4, %c9, %c7)
: (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
sparse_tensor.sort_coo insertion_sort_stable %i5, %xy jointly %y1 {nx = 3 : index, ny = 1 : index}
sparse_tensor.sort_coo insertion_sort_stable %i5, %xy jointly %y1 {perm_map = #ID_MAP, ny = 1 : index}
: memref<?xi32> jointly memref<?xi32>
%x0v2 = vector.transfer_read %x0[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
vector.print %x0v2 : vector<5xi32>
%x1v2 = vector.transfer_read %x1[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
vector.print %x1v2 : vector<5xi32>
%x2v2 = vector.transfer_read %x2[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
vector.print %x2v2 : vector<5xi32>
%x0v2 = vector.transfer_read %x0[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
vector.print %x0v2 : vector<5xi32>
%y0v2 = vector.transfer_read %y0[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
vector.print %y0v2 : vector<5xi32>
%y1v2 = vector.transfer_read %y1[%i0], %c100: memref<?xi32>, vector<5xi32>
vector.print %y1v2 : vector<5xi32>
// Heap sort.
// CHECK: ( 1, 1, 3, 3, 10 )
// CHECK: ( 2, 10, 1, 1, 5 )
// CHECK: ( 4, 2, 9, 9, 7 )
// CHECK: ( 10, 6, 8, 7, 9 )
// CHECK: ( 7, 5, 4, 7, 9 )
// CHECK: ( 1, 1, 2, 5, 10 )
// CHECK: ( 9, 9, 4, 7, 2 )
// CHECK: ( 3, 3, 1, 10, 1 )
// CHECK: ( 7, 8, 10, 9, 6 )
// CHECK: ( 7, 4, 7, 9, 5 )
call @storeValuesToStrided(%x0, %c1, %c1, %c3, %c10, %c3)
: (memref<?xi32, strided<[4], offset: ?>>, i32, i32, i32, i32, i32) -> ()
call @storeValuesToStrided(%x1, %c10, %c2, %c1, %c5, %c1)
@@ -165,14 +168,14 @@ module {
: (memref<?xi32, strided<[4], offset: ?>>, i32, i32, i32, i32, i32) -> ()
call @storeValuesTo(%y1, %c5, %c7, %c4, %c9, %c7)
: (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
sparse_tensor.sort_coo heap_sort %i5, %xy jointly %y1 {nx = 3 : index, ny = 1 : index}
sparse_tensor.sort_coo heap_sort %i5, %xy jointly %y1 {perm_map = #ID_MAP, ny = 1 : index}
: memref<?xi32> jointly memref<?xi32>
%x0v3 = vector.transfer_read %x0[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
vector.print %x0v3 : vector<5xi32>
%x1v3 = vector.transfer_read %x1[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
vector.print %x1v3 : vector<5xi32>
%x2v3 = vector.transfer_read %x2[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
vector.print %x2v3 : vector<5xi32>
%x0v3 = vector.transfer_read %x0[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
vector.print %x0v3 : vector<5xi32>
%y0v3 = vector.transfer_read %y0[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
vector.print %y0v3 : vector<5xi32>
%y1v3 = vector.transfer_read %y1[%i0], %c100: memref<?xi32>, vector<5xi32>