[mlir][sparse] unifies sparse_tensor.sort_coo/sort into one operation. (#66722)
The use cases of the two operations largely overlap, so simplify things by keeping only one of them.
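For illustration, a minimal sketch of the unified form (adapted from the tests updated in this change; the buffer names and shapes are placeholders): a former `sparse_tensor.sort` over several coordinate buffers is now expressed as a `sparse_tensor.sort_coo` on a single linear buffer, with the lexicographic key order given by a `perm_map` affine permutation instead of an `nx` count:

```mlir
#ID_MAP = affine_map<(d0, d1) -> (d0, d1)>
// Sort %n entries; each entry packs two key values (compared in the order
// given by #ID_MAP) plus ny = 1 trailing value in the linear buffer %xy,
// while the jointly listed buffer %vals is permuted along with the keys.
sparse_tensor.sort_coo hybrid_quick_sort %n, %xy jointly %vals
  {perm_map = #ID_MAP, ny = 1 : index}
  : memref<?xindex> jointly memref<?xf32>
```

An identity `perm_map` of the appropriate rank reproduces the old `nx = <rank>` behavior, which is how the rewriters below replace former `sort` uses.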
@@ -762,81 +762,32 @@ def SparseTensor_OutOp : SparseTensor_Op<"out", []>,
// Sparse Tensor Sorting Operations.
//===----------------------------------------------------------------------===//

def SparseTensor_SortOp : SparseTensor_Op<"sort", [AttrSizedOperandSegments]>,
Arguments<(ins Index:$n,
Variadic<StridedMemRefRankOf<[AnyInteger, Index], [1]>>:$xs,
Variadic<StridedMemRefRankOf<[AnyType], [1]>>:$ys,
SparseTensorSortKindAttr:$algorithm)> {
string summary = "Sorts the arrays in xs and ys lexicographically on the "
"integral values found in the xs list";
string description = [{
Lexicographically sort the first `n` values in `xs` along with the values in
`ys`. Conceptually, the values being sorted are tuples produced by
`zip(zip(xs), zip(ys))`. In particular, values in `ys` needed to be sorted
along with values in `xs`, but values in `ys` don't affect the
lexicographical order. The order in which arrays appear in `xs` affects the
sorting result. The operator updates `xs` and `ys` in place with the result
of the sorting.

For example, assume x1=[4, 3], x2=[1, 2], y1=[10, 5], then the output of
"sort 2, x1, x2 jointly y1" are x1=[3, 4], x2=[2, 1], y1=[5, 10] while the
output of "sort 2, x2, x1, jointly y1" are x2=[1, 2], x1=[4, 3], y1=[10, 5].

Buffers in `xs` needs to have the same integral element type while buffers
in `ys` can have different numeric element types. All buffers in `xs` and
`ys` should have a dimension not less than `n`. The behavior of the operator
is undefined if this condition is not met. The operator requires at least
one buffer in `xs` while `ys` can be empty.

The enum attribute `algorithm` indicates the sorting algorithm used to
implement the operator: hybrid_quick_sort, insertion_sort_stable,
quick_sort, or heap_sort.

Note that this operation is "impure" in the sense that its behavior is
solely defined by side-effects and not SSA values.

Example:

```mlir
sparse_tensor.sort insertion_sort_stable %n, %x1, %x2 jointly y1, %y2
: memref<?xindex>, memref<?xindex> jointly memref<?xindex>, memref<?xf32>
```

```mlir
sparse_tensor.sort hybrid_quick_sort %n, %x1, %x2 jointly y1, %y2
{ alg=1 : index}
: memref<?xindex>, memref<?xindex> jointly memref<?xindex>, memref<?xf32>
```
}];
let assemblyFormat = "$algorithm $n `,` $xs (`jointly` $ys^)? attr-dict"
"`:` type($xs) (`jointly` type($ys)^)?";
let hasVerifier = 1;
}

def SparseTensor_SortCooOp : SparseTensor_Op<"sort_coo">,
Arguments<(ins Index:$n, StridedMemRefRankOf<[AnyInteger, Index], [1]>:$xy,
Variadic<StridedMemRefRankOf<[AnyType], [1]>>:$ys,
OptionalAttr<IndexAttr>:$nx, OptionalAttr<IndexAttr>:$ny,
AffineMapAttr:$perm_map, OptionalAttr<IndexAttr>:$ny,
SparseTensorSortKindAttr:$algorithm)> {
let summary = "Sorts the arrays in xs and ys lexicographically on the "
"integral values found in the xs list";
let description = [{
Sparse_tensor.sort_coo is similar to sparse_tensor.sort, except that all the
`xs` values and some `ys` values are put in the linear buffer `xy`. The
optional index attribute `nx` provides the number of `xs` values in `xy`.
When `nx` is not explicitly specified, its value is 1. The optional index
attribute `ny` provides the number of `ys` values in `xy`. When `ny` is not
explicitly specified, its value is 0. This instruction supports a more
efficient way to store the COO definition in sparse tensor type.
Sparse_tensor.sort_coo sort the `xs` values along with some `ys` values
that are put in a single linear buffer `xy`.
The affine map attribute `perm_map` specifies the permutation to be applied on
the `xs` before comparison, the rank of the permutation map
also specifies the number of `xs` values in `xy`.
The optional index attribute `ny` provides the number of `ys` values in `xy`.
When `ny` is not explicitly specified, its value is 0.
This instruction supports a more efficient way to store the COO definition
in sparse tensor type.

The buffer xy should have a dimension not less than n * (nx + ny) while the
The buffer xy should have a dimension not less than n * (rank(perm_map) + ny) while the
buffers in `ys` should have a dimension not less than `n`. The behavior of
the operator is undefined if this condition is not met.

Example:

```mlir
sparse_tensor.sort_coo insertion_sort_stable %n, %x { nx = 2 : index}
sparse_tensor.sort_coo insertion_sort_stable %n, %x { perm_map = affine_map<(i,j) -> (j,i)> }
: memref<?xindex>
```

||||
@@ -1353,35 +1353,15 @@ LogicalResult SelectOp::verify() {
return success();
}

LogicalResult SortOp::verify() {
if (getXs().empty())
return emitError("need at least one xs buffer.");

std::optional<int64_t> n = getConstantIntValue(getN());

Type xtp = getMemRefType(getXs().front()).getElementType();
auto checkTypes = [&](ValueRange operands,
bool checkEleType = true) -> LogicalResult {
for (Value opnd : operands) {
auto mtp = getMemRefType(opnd);
const DynSize sh = mtp.getShape()[0];
// We can't check the size of dynamic dimension at compile-time, but all
// xs and ys should have a dimension not less than n at runtime.
if (n && !ShapedType::isDynamic(sh) && sh < n.value())
return emitError(llvm::formatv("xs and ys need to have a dimension >= n"
": {0} < {1}",
sh, n.value()));

if (checkEleType && xtp != mtp.getElementType())
return emitError("mismatch xs element types");
}
return success();
};
RETURN_FAILURE_IF_FAILED(checkTypes(getXs()))
return n ? checkTypes(getYs(), false) : success();
}

LogicalResult SortCooOp::verify() {
AffineMap xPerm = getPermMap();
uint64_t nx = xPerm.getNumDims();
if (nx < 1)
emitError(llvm::formatv("Expected rank(perm_map) > 1, got {0}", nx));

if (!xPerm.isPermutation())
emitError(llvm::formatv("Expected a permutation map, got {0}", xPerm));

std::optional<int64_t> cn = getConstantIntValue(getN());
// We can't check the size of the buffers when n or buffer dimensions aren't
// compile-time constants.
@@ -1389,12 +1369,6 @@ LogicalResult SortCooOp::verify() {
return success();

uint64_t n = cn.value();
uint64_t nx = 1;
if (auto nxAttr = getNxAttr()) {
nx = nxAttr.getInt();
if (nx < 1)
emitError(llvm::formatv("Expected nx > 1, got {0}", nx));
}
uint64_t ny = 0;
if (auto nyAttr = getNyAttr()) {
ny = nyAttr.getInt();
@@ -1409,7 +1383,8 @@ LogicalResult SortCooOp::verify() {
emitError(llvm::formatv("{0} got {1} < {2}", message, sh, minSize));
};

checkDim(getXy(), n * (nx + ny), "Expected dimension(xy) >= n * (nx + ny)");
checkDim(getXy(), n * (nx + ny),
"Expected dimension(xy) >= n * (rank(perm_map) + ny)");

for (Value opnd : getYs()) {
checkDim(opnd, n, "Expected dimension(y) >= n");

@@ -45,46 +45,43 @@ static constexpr const char kShiftDownFuncNamePrefix[] = "_sparse_shift_down_";
static constexpr const char kHeapSortFuncNamePrefix[] = "_sparse_heap_sort_";
static constexpr const char kQuickSortFuncNamePrefix[] = "_sparse_qsort_";

using FuncGeneratorType = function_ref<void(
OpBuilder &, ModuleOp, func::FuncOp, uint64_t, uint64_t, bool, uint32_t)>;
using FuncGeneratorType = function_ref<void(OpBuilder &, ModuleOp, func::FuncOp,
AffineMap, uint64_t, uint32_t)>;

/// Constructs a function name with this format to facilitate quick sort:
/// <namePrefix><nx>_<x type>_<y0 type>..._<yn type> for sort
/// <namePrefix><nx>_<x type>_coo_<ny>_<y0 type>..._<yn type> for sort_coo
/// <namePrefix><xPerm>_<x type>_<y0 type>..._<yn type> for sort
/// <namePrefix><xPerm>_<x type>_coo_<ny>_<y0 type>..._<yn type> for sort_coo
static void getMangledSortHelperFuncName(llvm::raw_svector_ostream &nameOstream,
StringRef namePrefix, uint64_t nx,
uint64_t ny, bool isCoo,
ValueRange operands) {
nameOstream << namePrefix << nx << "_"
<< getMemRefType(operands[xStartIdx]).getElementType();
StringRef namePrefix, AffineMap xPerm,
uint64_t ny, ValueRange operands) {
nameOstream << namePrefix;
for (auto res : xPerm.getResults())
nameOstream << res.cast<AffineDimExpr>().getPosition() << "_";

if (isCoo)
nameOstream << "_coo_" << ny;
nameOstream << getMemRefType(operands[xStartIdx]).getElementType();
nameOstream << "_coo_" << ny;

uint64_t yBufferOffset = isCoo ? 1 : nx;
constexpr uint64_t yBufferOffset = 1;
for (Value v : operands.drop_front(xStartIdx + yBufferOffset))
nameOstream << "_" << getMemRefType(v).getElementType();
}

/// Looks up a function that is appropriate for the given operands being
/// sorted, and creates such a function if it doesn't exist yet. The
/// parameters `nx` and `ny` tell the number of x and y values provided
/// by the buffer in xStartIdx, and `isCoo` indicates whether the instruction
/// being processed is a sparse_tensor.sort or sparse_tensor.sort_coo.
/// parameters `xPerm` and `ny` tell the number of x and y values provided
/// by the buffer in xStartIdx.
//
// All sorting function generators take (lo, hi, xs, ys) in `operands` as
// parameters for the sorting functions. Other parameters, such as the recursive
// call depth, are appended to the end of the parameter list as
// "trailing parameters".
static FlatSymbolRefAttr
getMangledSortHelperFunc(OpBuilder &builder, func::FuncOp insertPoint,
TypeRange resultTypes, StringRef namePrefix,
uint64_t nx, uint64_t ny, bool isCoo,
ValueRange operands, FuncGeneratorType createFunc,
uint32_t nTrailingP = 0) {
static FlatSymbolRefAttr getMangledSortHelperFunc(
OpBuilder &builder, func::FuncOp insertPoint, TypeRange resultTypes,
StringRef namePrefix, AffineMap xPerm, uint64_t ny, ValueRange operands,
FuncGeneratorType createFunc, uint32_t nTrailingP = 0) {
SmallString<32> nameBuffer;
llvm::raw_svector_ostream nameOstream(nameBuffer);
getMangledSortHelperFuncName(nameOstream, namePrefix, nx, ny, isCoo,
getMangledSortHelperFuncName(nameOstream, namePrefix, xPerm, ny,
operands.drop_back(nTrailingP));

ModuleOp module = insertPoint->getParentOfType<ModuleOp>();
@@ -101,7 +98,7 @@ getMangledSortHelperFunc(OpBuilder &builder, func::FuncOp insertPoint,
loc, nameOstream.str(),
FunctionType::get(context, operands.getTypes(), resultTypes));
func.setPrivate();
createFunc(builder, module, func, nx, ny, isCoo, nTrailingP);
createFunc(builder, module, func, xPerm, ny, nTrailingP);
}

return result;
@@ -110,27 +107,19 @@ getMangledSortHelperFunc(OpBuilder &builder, func::FuncOp insertPoint,
/// Creates a code block to process each pair of (xs[i], xs[j]) for sorting.
/// The code to process the value pairs is generated by `bodyBuilder`.
static void forEachIJPairInXs(
OpBuilder &builder, Location loc, ValueRange args, uint64_t nx, uint64_t ny,
bool isCoo, function_ref<void(uint64_t, Value, Value, Value)> bodyBuilder) {
Value iOffset, jOffset;
if (isCoo) {
Value cstep = constantIndex(builder, loc, nx + ny);
iOffset = builder.create<arith::MulIOp>(loc, args[0], cstep);
jOffset = builder.create<arith::MulIOp>(loc, args[1], cstep);
}
for (uint64_t k = 0; k < nx; k++) {
scf::IfOp ifOp;
Value i, j, buffer;
if (isCoo) {
Value ck = constantIndex(builder, loc, k);
i = builder.create<arith::AddIOp>(loc, ck, iOffset);
j = builder.create<arith::AddIOp>(loc, ck, jOffset);
buffer = args[xStartIdx];
} else {
i = args[0];
j = args[1];
buffer = args[xStartIdx + k];
}
OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm,
uint64_t ny,
function_ref<void(uint64_t, Value, Value, Value)> bodyBuilder) {
Value cstep = constantIndex(builder, loc, xPerm.getNumResults() + ny);
Value iOffset = builder.create<arith::MulIOp>(loc, args[0], cstep);
Value jOffset = builder.create<arith::MulIOp>(loc, args[1], cstep);
for (unsigned k = 0, e = xPerm.getNumResults(); k < e; k++) {
unsigned actualK = xPerm.getResult(k).cast<AffineDimExpr>().getPosition();
Value ak = constantIndex(builder, loc, actualK);
Value i = builder.create<arith::AddIOp>(loc, ak, iOffset);
Value j = builder.create<arith::AddIOp>(loc, ak, jOffset);
Value buffer = args[xStartIdx];

bodyBuilder(k, i, j, buffer);
}
}
@@ -138,21 +127,28 @@ static void forEachIJPairInXs(
/// Creates a code block to process each pair of (xys[i], xys[j]) for sorting.
/// The code to process the value pairs is generated by `bodyBuilder`.
static void forEachIJPairInAllBuffers(
OpBuilder &builder, Location loc, ValueRange args, uint64_t nx, uint64_t ny,
bool isCoo, function_ref<void(uint64_t, Value, Value, Value)> bodyBuilder) {
OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm,
uint64_t ny,
function_ref<void(uint64_t, Value, Value, Value)> bodyBuilder) {

// Create code for the first (nx + ny) buffers. When isCoo==true, these
// logical buffers are all from the xy buffer of the sort_coo operator.
forEachIJPairInXs(builder, loc, args, nx + ny, 0, isCoo, bodyBuilder);
// Create code for the first (xPerm + ny) buffers.
SmallVector<AffineExpr> exps(xPerm.getResults().begin(),
xPerm.getResults().end());
for (unsigned y = 0; y < ny; y++) {
exps.push_back(builder.getAffineDimExpr(y + xPerm.getNumResults()));
}
AffineMap xyPerm = AffineMap::get(exps.size(), 0, exps, builder.getContext());
assert(xyPerm.isPermutation());

uint64_t numHandledBuffers = isCoo ? 1 : nx + ny;
forEachIJPairInXs(builder, loc, args, xyPerm, 0, bodyBuilder);

constexpr uint64_t numHandledBuffers = 1;
// Create code for the remaining buffers.
Value i = args[0];
Value j = args[1];
for (const auto &arg :
llvm::enumerate(args.drop_front(xStartIdx + numHandledBuffers))) {
bodyBuilder(arg.index() + nx + ny, i, j, arg.value());
bodyBuilder(arg.index() + xPerm.getNumResults() + ny, i, j, arg.value());
}
}

||||
@@ -168,7 +164,7 @@ static void forEachIJPairInAllBuffers(
|
||||
// ...
|
||||
// swap(yn[i], yn[j]);
|
||||
static void createSwap(OpBuilder &builder, Location loc, ValueRange args,
|
||||
uint64_t nx, uint64_t ny, bool isCoo) {
|
||||
AffineMap xPerm, uint64_t ny) {
|
||||
auto swapOnePair = [&](uint64_t unused, Value i, Value j, Value buffer) {
|
||||
Value vi = builder.create<memref::LoadOp>(loc, buffer, i);
|
||||
Value vj = builder.create<memref::LoadOp>(loc, buffer, j);
|
||||
@@ -176,20 +172,20 @@ static void createSwap(OpBuilder &builder, Location loc, ValueRange args,
|
||||
builder.create<memref::StoreOp>(loc, vi, buffer, j);
|
||||
};
|
||||
|
||||
forEachIJPairInAllBuffers(builder, loc, args, nx, ny, isCoo, swapOnePair);
|
||||
forEachIJPairInAllBuffers(builder, loc, args, xPerm, ny, swapOnePair);
|
||||
}
|
||||
|
||||
/// Creates code to compare all the (xs[i], xs[j]) pairs. The method to compare
|
||||
/// each pair is create via `compareBuilder`.
|
||||
static Value createInlinedCompareImplementation(
|
||||
OpBuilder &builder, Location loc, ValueRange args, uint64_t nx, uint64_t ny,
|
||||
bool isCoo,
|
||||
OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm,
|
||||
uint64_t ny,
|
||||
function_ref<Value(OpBuilder &, Location, Value, Value, Value, bool, bool)>
|
||||
compareBuilder) {
|
||||
Value result;
|
||||
auto bodyBuilder = [&](uint64_t k, Value i, Value j, Value buffer) {
|
||||
bool isFirstDim = (k == 0);
|
||||
bool isLastDim = (k == nx - 1);
|
||||
bool isLastDim = (k == xPerm.getNumResults() - 1);
|
||||
Value val =
|
||||
compareBuilder(builder, loc, i, j, buffer, isFirstDim, isLastDim);
|
||||
if (isFirstDim) {
|
||||
@@ -202,7 +198,7 @@ static Value createInlinedCompareImplementation(
|
||||
}
|
||||
};
|
||||
|
||||
forEachIJPairInXs(builder, loc, args, nx, ny, isCoo, bodyBuilder);
|
||||
forEachIJPairInXs(builder, loc, args, xPerm, ny, bodyBuilder);
|
||||
|
||||
builder.setInsertionPointAfterValue(result);
|
||||
return result;
|
||||
@@ -252,12 +248,12 @@ static Value createEqCompare(OpBuilder &builder, Location loc, Value i, Value j,
|
||||
// else if (x2[2] != x2[j]))
|
||||
// and so on ...
|
||||
static Value createInlinedEqCompare(OpBuilder &builder, Location loc,
|
||||
ValueRange args, uint64_t nx, uint64_t ny,
|
||||
bool isCoo, uint32_t nTrailingP = 0) {
|
||||
ValueRange args, AffineMap xPerm,
|
||||
uint64_t ny, uint32_t nTrailingP = 0) {
|
||||
// Compare functions don't use trailing parameters.
|
||||
(void)nTrailingP;
|
||||
assert(nTrailingP == 0);
|
||||
return createInlinedCompareImplementation(builder, loc, args, nx, ny, isCoo,
|
||||
return createInlinedCompareImplementation(builder, loc, args, xPerm, ny,
|
||||
createEqCompare);
|
||||
}
|
||||
|
||||
@@ -306,12 +302,12 @@ static Value createLessThanCompare(OpBuilder &builder, Location loc, Value i,
|
||||
// else
|
||||
// and so on ...
|
||||
static Value createInlinedLessThan(OpBuilder &builder, Location loc,
|
||||
ValueRange args, uint64_t nx, uint64_t ny,
|
||||
bool isCoo, uint32_t nTrailingP = 0) {
|
||||
ValueRange args, AffineMap xPerm,
|
||||
uint64_t ny, uint32_t nTrailingP = 0) {
|
||||
// Compare functions don't use trailing parameters.
|
||||
(void)nTrailingP;
|
||||
assert(nTrailingP == 0);
|
||||
return createInlinedCompareImplementation(builder, loc, args, nx, ny, isCoo,
|
||||
return createInlinedCompareImplementation(builder, loc, args, xPerm, ny,
|
||||
createLessThanCompare);
|
||||
}
|
||||
|
||||
@@ -329,8 +325,8 @@ static Value createInlinedLessThan(OpBuilder &builder, Location loc,
|
||||
// return lo;
|
||||
//
|
||||
static void createBinarySearchFunc(OpBuilder &builder, ModuleOp module,
|
||||
func::FuncOp func, uint64_t nx, uint64_t ny,
|
||||
bool isCoo, uint32_t nTrailingP = 0) {
|
||||
func::FuncOp func, AffineMap xPerm,
|
||||
uint64_t ny, uint32_t nTrailingP = 0) {
|
||||
// Binary search doesn't use trailing parameters.
|
||||
(void)nTrailingP;
|
||||
assert(nTrailingP == 0);
|
||||
@@ -368,11 +364,10 @@ static void createBinarySearchFunc(OpBuilder &builder, ModuleOp module,
|
||||
|
||||
// Compare xs[p] < xs[mid].
|
||||
SmallVector<Value> compareOperands{p, mid};
|
||||
uint64_t numXBuffers = isCoo ? 1 : nx;
|
||||
constexpr uint64_t numXBuffers = 1;
|
||||
compareOperands.append(args.begin() + xStartIdx,
|
||||
args.begin() + xStartIdx + numXBuffers);
|
||||
Value cond2 =
|
||||
createInlinedLessThan(builder, loc, compareOperands, nx, ny, isCoo);
|
||||
Value cond2 = createInlinedLessThan(builder, loc, compareOperands, xPerm, ny);
|
||||
// Update lo and hi for the WhileOp as follows:
|
||||
// if (xs[p] < xs[mid]))
|
||||
// hi = mid;
|
||||
@@ -392,10 +387,11 @@ static void createBinarySearchFunc(OpBuilder &builder, ModuleOp module,
|
||||
/// while (xs[i] > xs[p]) i += step (step < 0)
|
||||
/// The routine returns i as well as a boolean value to indicate whether
|
||||
/// xs[i] == xs[p].
|
||||
static std::pair<Value, Value>
|
||||
createScanLoop(OpBuilder &builder, ModuleOp module, func::FuncOp func,
|
||||
ValueRange xs, Value i, Value p, uint64_t nx, uint64_t ny,
|
||||
bool isCoo, int step) {
|
||||
static std::pair<Value, Value> createScanLoop(OpBuilder &builder,
|
||||
ModuleOp module,
|
||||
func::FuncOp func, ValueRange xs,
|
||||
Value i, Value p, AffineMap xPerm,
|
||||
uint64_t ny, int step) {
|
||||
Location loc = func.getLoc();
|
||||
scf::WhileOp whileOp =
|
||||
builder.create<scf::WhileOp>(loc, TypeRange{i.getType()}, ValueRange{i});
|
||||
@@ -413,8 +409,7 @@ createScanLoop(OpBuilder &builder, ModuleOp module, func::FuncOp func,
|
||||
compareOperands.push_back(before->getArgument(0));
|
||||
}
|
||||
compareOperands.append(xs.begin(), xs.end());
|
||||
Value cond =
|
||||
createInlinedLessThan(builder, loc, compareOperands, nx, ny, isCoo);
|
||||
Value cond = createInlinedLessThan(builder, loc, compareOperands, xPerm, ny);
|
||||
builder.create<scf::ConditionOp>(loc, cond, before->getArguments());
|
||||
|
||||
Block *after =
|
||||
@@ -429,7 +424,7 @@ createScanLoop(OpBuilder &builder, ModuleOp module, func::FuncOp func,
|
||||
compareOperands[0] = i;
|
||||
compareOperands[1] = p;
|
||||
Value compareEq =
|
||||
createInlinedEqCompare(builder, loc, compareOperands, nx, ny, isCoo);
|
||||
createInlinedEqCompare(builder, loc, compareOperands, xPerm, ny);
|
||||
|
||||
return std::make_pair(whileOp.getResult(0), compareEq);
|
||||
}
|
||||
@@ -438,67 +433,63 @@ createScanLoop(OpBuilder &builder, ModuleOp module, func::FuncOp func,
|
||||
/// if compareFunc(data[b], data[a]) returns true. The new insertion point is
|
||||
/// right after the swap instructions.
|
||||
static scf::IfOp createCompareThenSwap(OpBuilder &builder, Location loc,
|
||||
uint64_t nx, uint64_t ny, bool isCoo,
|
||||
AffineMap xPerm, uint64_t ny,
|
||||
SmallVectorImpl<Value> &swapOperands,
|
||||
SmallVectorImpl<Value> &compareOperands,
|
||||
Value a, Value b) {
|
||||
// Compare(data[b], data[a]).
|
||||
compareOperands[0] = b;
|
||||
compareOperands[1] = a;
|
||||
Value cond =
|
||||
createInlinedLessThan(builder, loc, compareOperands, nx, ny, isCoo);
|
||||
Value cond = createInlinedLessThan(builder, loc, compareOperands, xPerm, ny);
|
||||
scf::IfOp ifOp = builder.create<scf::IfOp>(loc, cond, /*else=*/false);
|
||||
builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
|
||||
swapOperands[0] = b;
|
||||
swapOperands[1] = a;
|
||||
createSwap(builder, loc, swapOperands, nx, ny, isCoo);
|
||||
createSwap(builder, loc, swapOperands, xPerm, ny);
|
||||
return ifOp;
|
||||
}
|
||||
|
||||
/// Creates code to insert the 3rd element to a list of two sorted elements.
|
||||
static void createInsert3rd(OpBuilder &builder, Location loc, uint64_t nx,
|
||||
uint64_t ny, bool isCoo,
|
||||
SmallVectorImpl<Value> &swapOperands,
|
||||
static void createInsert3rd(OpBuilder &builder, Location loc, AffineMap xPerm,
|
||||
uint64_t ny, SmallVectorImpl<Value> &swapOperands,
|
||||
SmallVectorImpl<Value> &compareOperands, Value v0,
|
||||
Value v1, Value v2) {
|
||||
scf::IfOp ifOp = createCompareThenSwap(builder, loc, nx, ny, isCoo,
|
||||
swapOperands, compareOperands, v1, v2);
|
||||
createCompareThenSwap(builder, loc, nx, ny, isCoo, swapOperands,
|
||||
compareOperands, v0, v1);
|
||||
scf::IfOp ifOp = createCompareThenSwap(builder, loc, xPerm, ny, swapOperands,
|
||||
compareOperands, v1, v2);
|
||||
createCompareThenSwap(builder, loc, xPerm, ny, swapOperands, compareOperands,
|
||||
v0, v1);
|
||||
builder.setInsertionPointAfter(ifOp);
|
||||
}
|
||||
|
||||
/// Creates code to sort 3 elements.
|
||||
static void createSort3(OpBuilder &builder, Location loc, uint64_t nx,
|
||||
uint64_t ny, bool isCoo,
|
||||
SmallVectorImpl<Value> &swapOperands,
|
||||
static void createSort3(OpBuilder &builder, Location loc, AffineMap xPerm,
|
||||
uint64_t ny, SmallVectorImpl<Value> &swapOperands,
|
||||
SmallVectorImpl<Value> &compareOperands, Value v0,
|
||||
Value v1, Value v2) {
|
||||
// Sort the first 2 elements.
|
||||
scf::IfOp ifOp1 = createCompareThenSwap(
|
||||
builder, loc, nx, ny, isCoo, swapOperands, compareOperands, v0, v1);
|
||||
scf::IfOp ifOp1 = createCompareThenSwap(builder, loc, xPerm, ny, swapOperands,
|
||||
compareOperands, v0, v1);
|
||||
builder.setInsertionPointAfter(ifOp1);
|
||||
|
||||
// Insert the 3th element.
|
||||
createInsert3rd(builder, loc, nx, ny, isCoo, swapOperands, compareOperands,
|
||||
v0, v1, v2);
|
||||
createInsert3rd(builder, loc, xPerm, ny, swapOperands, compareOperands, v0,
|
||||
v1, v2);
|
||||
}
|
||||
|
||||
/// Creates code to sort 5 elements.
|
||||
static void createSort5(OpBuilder &builder, Location loc, uint64_t nx,
|
||||
uint64_t ny, bool isCoo,
|
||||
SmallVectorImpl<Value> &swapOperands,
|
||||
static void createSort5(OpBuilder &builder, Location loc, AffineMap xPerm,
|
||||
uint64_t ny, SmallVectorImpl<Value> &swapOperands,
|
||||
SmallVectorImpl<Value> &compareOperands, Value v0,
|
||||
Value v1, Value v2, Value v3, Value v4) {
|
||||
// Sort the first 3 elements.
|
||||
createSort3(builder, loc, nx, ny, isCoo, swapOperands, compareOperands, v0,
|
||||
v1, v2);
|
||||
createSort3(builder, loc, xPerm, ny, swapOperands, compareOperands, v0, v1,
|
||||
v2);
|
||||
|
||||
auto insert4th = [&]() {
|
||||
scf::IfOp ifOp = createCompareThenSwap(
|
||||
builder, loc, nx, ny, isCoo, swapOperands, compareOperands, v2, v3);
|
||||
createInsert3rd(builder, loc, nx, ny, isCoo, swapOperands, compareOperands,
|
||||
v0, v1, v2);
|
||||
builder, loc, xPerm, ny, swapOperands, compareOperands, v2, v3);
|
||||
createInsert3rd(builder, loc, xPerm, ny, swapOperands, compareOperands, v0,
|
||||
v1, v2);
|
||||
builder.setInsertionPointAfter(ifOp);
|
||||
};
|
||||
|
||||
@@ -506,8 +497,8 @@ static void createSort5(OpBuilder &builder, Location loc, uint64_t nx,
|
||||
insert4th();
|
||||
|
||||
// Insert the 5th element.
|
||||
scf::IfOp ifOp = createCompareThenSwap(builder, loc, nx, ny, isCoo,
|
||||
swapOperands, compareOperands, v3, v4);
|
||||
scf::IfOp ifOp = createCompareThenSwap(builder, loc, xPerm, ny, swapOperands,
|
||||
compareOperands, v3, v4);
|
||||
insert4th();
|
||||
builder.setInsertionPointAfter(ifOp);
|
||||
}
|
||||
@@ -517,11 +508,10 @@ static void createSort5(OpBuilder &builder, Location loc, uint64_t nx,
|
||||
/// the number of values in range [lo, hi) is more than a threshold, we also
|
||||
/// include the middle of [lo, mi) and [mi, hi) and sort a total of five values.
|
||||
static void createChoosePivot(OpBuilder &builder, ModuleOp module,
|
||||
func::FuncOp func, uint64_t nx, uint64_t ny,
|
||||
bool isCoo, Value lo, Value hi, Value mi,
|
||||
ValueRange args) {
|
||||
func::FuncOp func, AffineMap xPerm, uint64_t ny,
|
||||
Value lo, Value hi, Value mi, ValueRange args) {
|
||||
SmallVector<Value> compareOperands{mi, lo};
|
||||
uint64_t numXBuffers = isCoo ? 1 : nx;
|
||||
constexpr uint64_t numXBuffers = 1;
|
||||
compareOperands.append(args.begin() + xStartIdx,
|
||||
args.begin() + xStartIdx + numXBuffers);
|
||||
SmallVector<Value> swapOperands{mi, lo};
|
||||
@@ -537,8 +527,8 @@ static void createChoosePivot(OpBuilder &builder, ModuleOp module,
|
||||
|
||||
// When len < 1000, choose pivot from median of 3 values.
|
||||
builder.setInsertionPointToStart(&lenIf.getThenRegion().front());
|
||||
createSort3(builder, loc, nx, ny, isCoo, swapOperands, compareOperands, lo,
|
||||
mi, hi);
|
||||
createSort3(builder, loc, xPerm, ny, swapOperands, compareOperands, lo, mi,
|
||||
hi);
|
||||
|
||||
// When len >= 1000, choose pivot from median of 5 values.
|
||||
builder.setInsertionPointToStart(&lenIf.getElseRegion().front());
|
||||
@@ -549,8 +539,8 @@ static void createChoosePivot(OpBuilder &builder, ModuleOp module,
|
||||
Value b = builder.create<arith::AddIOp>(loc, mi, hiP1);
|
||||
// Value b is the middle between [mi, hi].
|
||||
b = builder.create<arith::ShRUIOp>(loc, b, c1);
|
||||
createSort5(builder, loc, nx, ny, isCoo, swapOperands, compareOperands, lo, a,
|
||||
mi, b, hi);
|
||||
createSort5(builder, loc, xPerm, ny, swapOperands, compareOperands, lo, a, mi,
|
||||
b, hi);
|
||||
|
||||
builder.setInsertionPointAfter(lenIf);
|
||||
}
|
||||
@@ -586,8 +576,8 @@ static void createChoosePivot(OpBuilder &builder, ModuleOp module,
|
||||
// }
|
||||
// }
|
||||
static void createPartitionFunc(OpBuilder &builder, ModuleOp module,
|
||||
func::FuncOp func, uint64_t nx, uint64_t ny,
|
||||
bool isCoo, uint32_t nTrailingP = 0) {
|
||||
func::FuncOp func, AffineMap xPerm, uint64_t ny,
|
||||
uint32_t nTrailingP = 0) {
|
||||
// Quick sort partition doesn't use trailing parameters.
|
||||
(void)nTrailingP;
|
||||
assert(nTrailingP == 0);
|
||||
@@ -606,7 +596,7 @@ static void createPartitionFunc(OpBuilder &builder, ModuleOp module,
|
||||
|
||||
Value i = lo;
|
||||
Value j = builder.create<arith::SubIOp>(loc, hi, c1);
|
||||
createChoosePivot(builder, module, func, nx, ny, isCoo, i, j, p, args);
|
||||
createChoosePivot(builder, module, func, xPerm, ny, i, j, p, args);
|
||||
Value trueVal = constantI1(builder, loc, true); // The value for while (true)
|
||||
SmallVector<Value, 4> operands{i, j, p, trueVal}; // Exactly four values.
|
||||
SmallVector<Type, 4> types{i.getType(), j.getType(), p.getType(),
|
||||
@@ -628,14 +618,14 @@ static void createPartitionFunc(OpBuilder &builder, ModuleOp module,
|
||||
j = after->getArgument(1);
|
||||
p = after->getArgument(2);
|
||||
|
||||
uint64_t numXBuffers = isCoo ? 1 : nx;
|
||||
constexpr uint64_t numXBuffers = 1;
|
||||
auto [iresult, iCompareEq] =
|
||||
createScanLoop(builder, module, func, args.slice(xStartIdx, numXBuffers),
|
||||
i, p, nx, ny, isCoo, 1);
|
||||
i, p, xPerm, ny, 1);
|
||||
i = iresult;
|
||||
auto [jresult, jCompareEq] =
|
||||
createScanLoop(builder, module, func, args.slice(xStartIdx, numXBuffers),
|
||||
j, p, nx, ny, isCoo, -1);
|
||||
j, p, xPerm, ny, -1);
|
||||
j = jresult;
|
||||
|
||||
// If i < j:
|
||||
@@ -645,7 +635,7 @@ static void createPartitionFunc(OpBuilder &builder, ModuleOp module,
|
||||
builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
|
||||
SmallVector<Value> swapOperands{i, j};
|
||||
swapOperands.append(args.begin() + xStartIdx, args.end());
|
||||
createSwap(builder, loc, swapOperands, nx, ny, isCoo);
|
||||
createSwap(builder, loc, swapOperands, xPerm, ny);
|
||||
// If the pivot is moved, update p with the new pivot.
|
||||
Value icond =
|
||||
builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, i, p);
|
||||
@@ -737,8 +727,8 @@ static Value createSubTwoDividedByTwo(OpBuilder &builder, Location loc,
|
||||
// }
|
||||
//
|
||||
static void createShiftDownFunc(OpBuilder &builder, ModuleOp module,
|
||||
func::FuncOp func, uint64_t nx, uint64_t ny,
|
||||
bool isCoo, uint32_t nTrailingP) {
|
||||
func::FuncOp func, AffineMap xPerm, uint64_t ny,
|
||||
uint32_t nTrailingP) {
|
||||
// The value n is passed in as a trailing parameter.
|
||||
assert(nTrailingP == 1);
|
||||
OpBuilder::InsertionGuard insertionGuard(builder);
|
||||
@@ -768,7 +758,7 @@ static void createShiftDownFunc(OpBuilder &builder, ModuleOp module,
|
||||
builder.setInsertionPointToStart(&ifNc.getThenRegion().front());
|
||||
Value c1 = constantIndex(builder, loc, 1);
|
||||
SmallVector<Value> compareOperands{start, start};
|
||||
uint64_t numXBuffers = isCoo ? 1 : nx;
|
||||
constexpr uint64_t numXBuffers = 1;
|
||||
compareOperands.append(args.begin() + xStartIdx,
|
||||
args.begin() + xStartIdx + numXBuffers);
|
||||
|
||||
@@ -794,7 +784,7 @@ static void createShiftDownFunc(OpBuilder &builder, ModuleOp module,
|
||||
compareOperands[0] = lChildIdx;
|
||||
compareOperands[1] = rChildIdx;
|
||||
Value cond2 =
|
||||
createInlinedLessThan(builder, loc, compareOperands, nx, ny, isCoo);
|
||||
createInlinedLessThan(builder, loc, compareOperands, xPerm, ny);
|
||||
scf::IfOp if2 =
|
||||
builder.create<scf::IfOp>(loc, ifTypes, cond2, /*else=*/true);
|
||||
builder.setInsertionPointToStart(&if2.getThenRegion().front());
|
||||
@@ -825,8 +815,7 @@ static void createShiftDownFunc(OpBuilder &builder, ModuleOp module,
|
||||
childIdx = before->getArgument(2);
|
||||
compareOperands[0] = start;
|
||||
compareOperands[1] = childIdx;
|
||||
Value cond =
|
||||
createInlinedLessThan(builder, loc, compareOperands, nx, ny, isCoo);
|
||||
Value cond = createInlinedLessThan(builder, loc, compareOperands, xPerm, ny);
|
||||
builder.create<scf::ConditionOp>(loc, cond, before->getArguments());
|
||||
|
||||
// The after-region of the WhileOp.
|
||||
@@ -836,7 +825,7 @@ static void createShiftDownFunc(OpBuilder &builder, ModuleOp module,
|
||||
childIdx = after->getArgument(2);
|
||||
SmallVector<Value> swapOperands{start, childIdx};
|
||||
swapOperands.append(args.begin() + xStartIdx, args.end());
|
||||
createSwap(builder, loc, swapOperands, nx, ny, isCoo);
|
||||
createSwap(builder, loc, swapOperands, xPerm, ny);
|
||||
start = childIdx;
|
||||
Value cond2 =
|
||||
builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge, t, child);
|
||||
@@ -869,8 +858,8 @@ static void createShiftDownFunc(OpBuilder &builder, ModuleOp module,
|
||||
// shiftdown(lo, lo, l-1)
|
||||
// }
|
||||
static void createHeapSortFunc(OpBuilder &builder, ModuleOp module,
|
||||
func::FuncOp func, uint64_t nx, uint64_t ny,
|
||||
bool isCoo, uint32_t nTrailingP) {
|
||||
func::FuncOp func, AffineMap xPerm, uint64_t ny,
|
||||
uint32_t nTrailingP) {
|
||||
// Heap sort function doesn't have trailing parameters.
|
||||
(void)nTrailingP;
|
||||
assert(nTrailingP == 0);
|
||||
@@ -897,7 +886,7 @@ static void createHeapSortFunc(OpBuilder &builder, ModuleOp module,
|
||||
shiftDownOperands.append(args.begin() + xStartIdx, args.end());
|
||||
shiftDownOperands.push_back(n);
|
||||
FlatSymbolRefAttr shiftDownFunc = getMangledSortHelperFunc(
|
||||
builder, func, TypeRange(), kShiftDownFuncNamePrefix, nx, ny, isCoo,
|
||||
builder, func, TypeRange(), kShiftDownFuncNamePrefix, xPerm, ny,
|
||||
shiftDownOperands, createShiftDownFunc, /*nTrailingP=*/1);
|
||||
builder.create<func::CallOp>(loc, shiftDownFunc, TypeRange(),
|
||||
shiftDownOperands);
|
||||
@@ -912,7 +901,7 @@ static void createHeapSortFunc(OpBuilder &builder, ModuleOp module,
|
||||
loplm1 = builder.create<arith::SubIOp>(loc, loplm1, c1);
|
||||
SmallVector<Value> swapOperands{lo, loplm1};
|
||||
swapOperands.append(args.begin() + xStartIdx, args.end());
|
||||
createSwap(builder, loc, swapOperands, nx, ny, isCoo);
|
||||
createSwap(builder, loc, swapOperands, xPerm, ny);
|
||||
shiftDownOperands[1] = lo;
|
||||
shiftDownOperands[shiftDownOperands.size() - 1] =
|
||||
builder.create<arith::SubIOp>(loc, l, c1);
|
||||
@@ -928,7 +917,7 @@ static void createHeapSortFunc(OpBuilder &builder, ModuleOp module,
|
||||
/// the bigger partition to be processed by the enclosed while-loop.
|
||||
static std::pair<Value, Value>
|
||||
createQuickSort(OpBuilder &builder, ModuleOp module, func::FuncOp func,
|
||||
ValueRange args, uint64_t nx, uint64_t ny, bool isCoo,
|
||||
ValueRange args, AffineMap xPerm, uint64_t ny,
|
||||
uint32_t nTrailingP) {
|
||||
MLIRContext *context = module.getContext();
|
||||
Location loc = func.getLoc();
|
||||
@@ -937,8 +926,8 @@ createQuickSort(OpBuilder &builder, ModuleOp module, func::FuncOp func,
|
||||
SmallVector<Type, 2> types(2, lo.getType()); // Only two types.
|
||||
|
||||
FlatSymbolRefAttr partitionFunc = getMangledSortHelperFunc(
|
||||
builder, func, {IndexType::get(context)}, kPartitionFuncNamePrefix, nx,
|
||||
ny, isCoo, args.drop_back(nTrailingP), createPartitionFunc);
|
||||
builder, func, {IndexType::get(context)}, kPartitionFuncNamePrefix, xPerm,
|
||||
ny, args.drop_back(nTrailingP), createPartitionFunc);
|
||||
Value p = builder
|
||||
.create<func::CallOp>(loc, partitionFunc,
|
||||
TypeRange{IndexType::get(context)},
|
||||
@@ -1008,8 +997,8 @@ createQuickSort(OpBuilder &builder, ModuleOp module, func::FuncOp func,
|
||||
// }
|
||||
// }
|
||||
static void createSortStableFunc(OpBuilder &builder, ModuleOp module,
|
||||
func::FuncOp func, uint64_t nx, uint64_t ny,
|
||||
bool isCoo, uint32_t nTrailingP) {
|
||||
func::FuncOp func, AffineMap xPerm,
|
||||
uint64_t ny, uint32_t nTrailingP) {
|
||||
// Stable sort function doesn't use trailing parameters.
|
||||
(void)nTrailingP;
|
||||
assert(nTrailingP == 0);
|
||||
@@ -1034,8 +1023,8 @@ static void createSortStableFunc(OpBuilder &builder, ModuleOp module,
|
||||
SmallVector<Value> operands{lo, i};
|
||||
operands.append(args.begin() + xStartIdx, args.end());
|
||||
FlatSymbolRefAttr searchFunc = getMangledSortHelperFunc(
|
||||
builder, func, {IndexType::get(context)}, kBinarySearchFuncNamePrefix, nx,
|
||||
ny, isCoo, operands, createBinarySearchFunc);
|
||||
builder, func, {IndexType::get(context)}, kBinarySearchFuncNamePrefix,
|
||||
xPerm, ny, operands, createBinarySearchFunc);
|
||||
Value p = builder
|
||||
.create<func::CallOp>(loc, searchFunc, TypeRange{c1.getType()},
|
||||
operands)
|
||||
@@ -1045,7 +1034,7 @@ static void createSortStableFunc(OpBuilder &builder, ModuleOp module,
|
||||
operands[0] = operands[1] = i;
|
||||
SmallVector<Value> d;
|
||||
forEachIJPairInAllBuffers(
|
||||
builder, loc, operands, nx, ny, isCoo,
|
||||
builder, loc, operands, xPerm, ny,
|
||||
[&](uint64_t unused, Value i, Value unused2, Value buffer) {
|
||||
d.push_back(builder.create<memref::LoadOp>(loc, buffer, i));
|
||||
});
|
||||
@@ -1061,7 +1050,7 @@ static void createSortStableFunc(OpBuilder &builder, ModuleOp module,
|
||||
operands[1] = imj;
|
||||
operands[0] = builder.create<arith::SubIOp>(loc, imj, c1);
|
||||
forEachIJPairInAllBuffers(
|
||||
builder, loc, operands, nx, ny, isCoo,
|
||||
builder, loc, operands, xPerm, ny,
|
||||
[&](uint64_t unused, Value imjm1, Value imj, Value buffer) {
|
||||
Value t = builder.create<memref::LoadOp>(loc, buffer, imjm1);
|
||||
builder.create<memref::StoreOp>(loc, t, buffer, imj);
|
||||
@@ -1071,7 +1060,7 @@ static void createSortStableFunc(OpBuilder &builder, ModuleOp module,
|
||||
builder.setInsertionPointAfter(forOpJ);
|
||||
operands[0] = operands[1] = p;
|
||||
forEachIJPairInAllBuffers(
|
||||
builder, loc, operands, nx, ny, isCoo,
|
||||
builder, loc, operands, xPerm, ny,
|
||||
[&](uint64_t k, Value p, Value usused, Value buffer) {
|
||||
builder.create<memref::StoreOp>(loc, d[k], buffer, p);
|
||||
});
|
||||
@@ -1123,8 +1112,8 @@ static void createSortStableFunc(OpBuilder &builder, ModuleOp module,
|
||||
// }
|
||||
//
|
||||
static void createQuickSortFunc(OpBuilder &builder, ModuleOp module,
|
||||
func::FuncOp func, uint64_t nx, uint64_t ny,
|
||||
bool isCoo, uint32_t nTrailingP) {
|
||||
func::FuncOp func, AffineMap xPerm, uint64_t ny,
|
||||
uint32_t nTrailingP) {
|
||||
assert(nTrailingP == 1 || nTrailingP == 0);
|
||||
bool isHybrid = (nTrailingP == 1);
|
||||
OpBuilder::InsertionGuard insertionGuard(builder);
|
||||
@@ -1173,7 +1162,7 @@ static void createQuickSortFunc(OpBuilder &builder, ModuleOp module,
|
||||
// When len <= limit.
|
||||
builder.setInsertionPointToStart(&lenIf.getThenRegion().front());
|
||||
FlatSymbolRefAttr insertionSortFunc = getMangledSortHelperFunc(
|
||||
builder, func, TypeRange(), kSortStableFuncNamePrefix, nx, ny, isCoo,
|
||||
builder, func, TypeRange(), kSortStableFuncNamePrefix, xPerm, ny,
|
||||
ValueRange(args).drop_back(nTrailingP), createSortStableFunc);
|
||||
builder.create<func::CallOp>(loc, insertionSortFunc, TypeRange(),
|
||||
ValueRange(args).drop_back(nTrailingP));
|
||||
@@ -1193,7 +1182,7 @@ static void createQuickSortFunc(OpBuilder &builder, ModuleOp module,
|
||||
// When depth exceeds limit.
|
||||
builder.setInsertionPointToStart(&depthIf.getThenRegion().front());
|
||||
FlatSymbolRefAttr heapSortFunc = getMangledSortHelperFunc(
|
||||
builder, func, TypeRange(), kHeapSortFuncNamePrefix, nx, ny, isCoo,
|
||||
builder, func, TypeRange(), kHeapSortFuncNamePrefix, xPerm, ny,
|
||||
ValueRange(args).drop_back(nTrailingP), createHeapSortFunc);
|
||||
builder.create<func::CallOp>(loc, heapSortFunc, TypeRange(),
|
||||
ValueRange(args).drop_back(nTrailingP));
|
||||
@@ -1203,7 +1192,7 @@ static void createQuickSortFunc(OpBuilder &builder, ModuleOp module,
|
||||
builder.setInsertionPointToStart(&depthIf.getElseRegion().front());
|
||||
args.back() = depthLimit;
|
||||
std::tie(lo, hi) =
|
||||
createQuickSort(builder, module, func, args, nx, ny, isCoo, nTrailingP);
|
||||
createQuickSort(builder, module, func, args, xPerm, ny, nTrailingP);
|
||||
builder.create<scf::YieldOp>(loc, ValueRange{lo, hi});
|
||||
|
||||
builder.setInsertionPointAfter(depthIf);
|
||||
@@ -1216,7 +1205,7 @@ static void createQuickSortFunc(OpBuilder &builder, ModuleOp module,
|
||||
hi = lenIf.getResult(1);
|
||||
} else {
|
||||
std::tie(lo, hi) =
|
||||
createQuickSort(builder, module, func, args, nx, ny, isCoo, nTrailingP);
|
||||
createQuickSort(builder, module, func, args, xPerm, ny, nTrailingP);
|
||||
}
|
||||
|
||||
// New [lo, hi) for the next while-loop iteration.
|
||||
@@ -1229,9 +1218,8 @@ static void createQuickSortFunc(OpBuilder &builder, ModuleOp module,
|
||||
|
||||
/// Implements the rewriting for operator sort and sort_coo.
|
||||
template <typename OpTy>
|
||||
LogicalResult matchAndRewriteSortOp(OpTy op, ValueRange xys, uint64_t nx,
|
||||
uint64_t ny, bool isCoo,
|
||||
PatternRewriter &rewriter) {
|
||||
LogicalResult matchAndRewriteSortOp(OpTy op, ValueRange xys, AffineMap xPerm,
|
||||
uint64_t ny, PatternRewriter &rewriter) {
|
||||
Location loc = op.getLoc();
|
||||
SmallVector<Value> operands{constantIndex(rewriter, loc, 0), op.getN()};
|
||||
|
||||
@@ -1285,8 +1273,8 @@ LogicalResult matchAndRewriteSortOp(OpTy op, ValueRange xys, uint64_t nx,
|
||||
}
|
||||
|
||||
FlatSymbolRefAttr func =
|
||||
getMangledSortHelperFunc(rewriter, insertPoint, TypeRange(), funcName, nx,
|
||||
ny, isCoo, operands, funcGenerator, nTrailingP);
|
||||
getMangledSortHelperFunc(rewriter, insertPoint, TypeRange(), funcName,
|
||||
xPerm, ny, operands, funcGenerator, nTrailingP);
|
||||
rewriter.replaceOpWithNewOp<func::CallOp>(op, func, TypeRange(), operands);
|
||||
return success();
|
||||
}
|
||||
@@ -1296,7 +1284,6 @@ LogicalResult matchAndRewriteSortOp(OpTy op, ValueRange xys, uint64_t nx,
|
||||
//===---------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
|
||||
/// Sparse rewriting rule for the push_back operator.
|
||||
struct PushBackRewriter : OpRewritePattern<PushBackOp> {
|
||||
public:
|
||||
@@ -1410,20 +1397,6 @@ private:
|
||||
bool enableBufferInitialization;
|
||||
};
|
||||
|
||||
/// Sparse rewriting rule for the sort operator.
|
||||
struct SortRewriter : public OpRewritePattern<SortOp> {
|
||||
public:
|
||||
using OpRewritePattern<SortOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(SortOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
SmallVector<Value> xys(op.getXs());
|
||||
xys.append(op.getYs().begin(), op.getYs().end());
|
||||
return matchAndRewriteSortOp(op, xys, op.getXs().size(), /*ny=*/0,
|
||||
/*isCoo=*/false, rewriter);
|
||||
}
|
||||
};
|
||||
|
||||
/// Sparse rewriting rule for the sort_coo operator.
|
||||
struct SortCooRewriter : public OpRewritePattern<SortCooOp> {
|
||||
public:
|
||||
@@ -1434,16 +1407,13 @@ public:
|
||||
SmallVector<Value> xys;
|
||||
xys.push_back(op.getXy());
|
||||
xys.append(op.getYs().begin(), op.getYs().end());
|
||||
uint64_t nx = 1;
|
||||
if (auto nxAttr = op.getNxAttr())
|
||||
nx = nxAttr.getInt();
|
||||
|
||||
auto xPerm = op.getPermMap();
|
||||
uint64_t ny = 0;
|
||||
if (auto nyAttr = op.getNyAttr())
|
||||
ny = nyAttr.getInt();
|
||||
|
||||
return matchAndRewriteSortOp(op, xys, nx, ny,
|
||||
/*isCoo=*/true, rewriter);
|
||||
return matchAndRewriteSortOp(op, xys, xPerm, ny, rewriter);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1457,5 +1427,5 @@ void mlir::populateSparseBufferRewriting(RewritePatternSet &patterns,
|
||||
bool enableBufferInitialization) {
|
||||
patterns.add<PushBackRewriter>(patterns.getContext(),
|
||||
enableBufferInitialization);
|
||||
patterns.add<SortRewriter, SortCooRewriter>(patterns.getContext());
|
||||
patterns.add<SortCooRewriter>(patterns.getContext());
|
||||
}
|
||||
|
||||
@@ -890,8 +890,9 @@ public:
|
||||
// If the innermost level is ordered, we need to sort the coordinates
|
||||
// in the "added" array prior to applying the compression.
|
||||
if (dstType.isOrderedLvl(dstType.getLvlRank() - 1))
|
||||
rewriter.create<SortOp>(loc, count, ValueRange{added}, ValueRange{},
|
||||
SparseTensorSortKind::HybridQuickSort);
|
||||
rewriter.create<SortCooOp>(
|
||||
loc, count, added, ValueRange{}, rewriter.getMultiDimIdentityMap(1),
|
||||
rewriter.getIndexAttr(0), SparseTensorSortKind::HybridQuickSort);
|
||||
// While performing the insertions, we also need to reset the elements
|
||||
// of the values/filled-switch by only iterating over the set elements,
|
||||
// to ensure that the runtime complexity remains proportional to the
|
||||
@@ -1486,9 +1487,10 @@ struct SparseNewOpConverter : public OpConversionPattern<NewOp> {
|
||||
scf::IfOp ifOp =
|
||||
rewriter.create<scf::IfOp>(loc, notSorted, /*else*/ false);
|
||||
rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
|
||||
rewriter.create<SortCooOp>(
|
||||
loc, nse, xs, ValueRange{ys}, rewriter.getIndexAttr(lvlRank),
|
||||
rewriter.getIndexAttr(0), SparseTensorSortKind::HybridQuickSort);
|
||||
auto xPerm = rewriter.getMultiDimIdentityMap(lvlRank);
|
||||
rewriter.create<SortCooOp>(loc, nse, xs, ValueRange{ys}, xPerm,
|
||||
rewriter.getIndexAttr(0),
|
||||
SparseTensorSortKind::HybridQuickSort);
|
||||
rewriter.setInsertionPointAfter(ifOp);
|
||||
}
|
||||
|
||||
|
||||
@@ -207,7 +207,6 @@ struct SparseTensorCodegenPass
|
||||
ConversionTarget target(*ctx);
|
||||
// Most ops in the sparse dialect must go!
|
||||
target.addIllegalDialect<SparseTensorDialect>();
|
||||
target.addLegalOp<SortOp>();
|
||||
target.addLegalOp<SortCooOp>();
|
||||
target.addLegalOp<PushBackOp>();
|
||||
// Storage specifier outlives sparse tensor pipeline.
|
||||
|
||||
@@ -1206,29 +1206,23 @@ private:
|
||||
// Retrieve the values-array.
|
||||
Value y = genToValues(rewriter, loc, src);
|
||||
const auto encSrc = srcTp.getEncoding();
|
||||
// Sort the COO tensor so that its elements are ordered via increasing
|
||||
// coordinates for the storage ordering of the dst tensor. Use SortCoo
|
||||
// if the COO tensor has the same ordering as the dst tensor.
|
||||
if (dimRank > 1 && srcTp.hasSameDimToLvl(dstTp)) {
|
||||
Value xs = genToCoordinatesBuffer(rewriter, loc, src);
|
||||
rewriter.create<SortCooOp>(
|
||||
loc, nnz, xs, ValueRange{y}, rewriter.getIndexAttr(dimRank),
|
||||
rewriter.getIndexAttr(0), SparseTensorSortKind::HybridQuickSort);
|
||||
} else {
|
||||
// Gather the coordinates-arrays in the dst tensor storage order.
|
||||
SmallVector<Value> xs(dstLvlRank);
|
||||
const Level srcLvlRank = srcTp.getLvlRank();
|
||||
for (Level srcLvl = 0; srcLvl < srcLvlRank; srcLvl++) {
|
||||
// FIXME: `toOrigDim` is deprecated
|
||||
Dimension dim = toOrigDim(encSrc, srcLvl);
|
||||
// FIXME: `toStoredDim` is deprecated
|
||||
Level dstLvl = toStoredDim(encDst, dim);
|
||||
xs[dstLvl] =
|
||||
genToCoordinates(rewriter, loc, src, srcLvl, /*cooStart=*/0);
|
||||
}
|
||||
rewriter.create<SortOp>(loc, nnz, xs, ValueRange{y},
|
||||
SparseTensorSortKind::HybridQuickSort);
|
||||
// Builds the dstLvl -> srcLvl permutation maps.
|
||||
SmallVector<AffineExpr> es(dstLvlRank);
|
||||
const Level srcLvlRank = srcTp.getLvlRank();
|
||||
for (Level srcLvl = 0; srcLvl < srcLvlRank; srcLvl++) {
|
||||
// FIXME: `toOrigDim` is deprecated
|
||||
Dimension dim = toOrigDim(encSrc, srcLvl);
|
||||
// FIXME: `toStoredDim` is deprecated
|
||||
Level dstLvl = toStoredDim(encDst, dim);
|
||||
es[dstLvl] = rewriter.getAffineDimExpr(srcLvl);
|
||||
}
|
||||
auto xPerm = AffineMap::get(dstLvlRank, 0, es, rewriter.getContext());
|
||||
assert(xPerm.isPermutation()); // must be a permutation.
|
||||
|
||||
Value xs = genToCoordinatesBuffer(rewriter, loc, src);
|
||||
rewriter.create<SortCooOp>(loc, nnz, xs, ValueRange{y}, xPerm,
|
||||
rewriter.getIndexAttr(0),
|
||||
SparseTensorSortKind::HybridQuickSort);
|
||||
}
|
||||
|
||||
// For each element in the COO tensor, insert the element to the dst tensor.
|
||||
|
||||
@@ -75,123 +75,64 @@ func.func @sparse_push_back_inbound(%arg0: index, %arg1: memref<?xf64>, %arg2: f
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func private @_sparse_partition_1_i8_f32_index
|
||||
// CHECK-LABEL: func.func private @_sparse_qsort_1_i8_f32_index
|
||||
// CHECK-LABEL: func.func @sparse_sort_1d2v_quick
|
||||
func.func @sparse_sort_1d2v_quick(%arg0: index, %arg1: memref<10xi8>, %arg2: memref<?xf32>, %arg3: memref<10xindex>)
|
||||
-> (memref<10xi8>, memref<?xf32>, memref<10xindex>) {
|
||||
sparse_tensor.sort quick_sort %arg0, %arg1 jointly %arg2, %arg3 : memref<10xi8> jointly memref<?xf32>, memref<10xindex>
|
||||
return %arg1, %arg2, %arg3 : memref<10xi8>, memref<?xf32>, memref<10xindex>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Only check the generated supporting function now. We have integration test
|
||||
// to verify correctness of the generated code.
|
||||
//
|
||||
// CHECK-DAG: func.func private @_sparse_partition_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) -> index {
|
||||
// CHECK-DAG: func.func private @_sparse_qsort_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) {
|
||||
// CHECK-LABEL: func.func @sparse_sort_3d_quick
|
||||
func.func @sparse_sort_3d_quick(%arg0: index, %arg1: memref<10xindex>, %arg2: memref<?xindex>, %arg3: memref<10xindex>) -> (memref<10xindex>, memref<?xindex>, memref<10xindex>) {
|
||||
sparse_tensor.sort quick_sort %arg0, %arg1, %arg2, %arg3 : memref<10xindex>, memref<?xindex>, memref<10xindex>
|
||||
return %arg1, %arg2, %arg3 : memref<10xindex>, memref<?xindex>, memref<10xindex>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Only check the generated supporting function now. We have integration test
|
||||
// to verify correctness of the generated code.
|
||||
//
|
||||
// CHECK-DAG: func.func private @_sparse_binary_search_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) -> index {
|
||||
// CHECK-DAG: func.func private @_sparse_sort_stable_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) {
|
||||
// CHECK-DAG: func.func private @_sparse_shift_down_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>, %arg5: index) {
|
||||
// CHECK-DAG: func.func private @_sparse_heap_sort_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) {
|
||||
// CHECK-DAG: func.func private @_sparse_partition_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) -> index {
|
||||
// CHECK-DAG: func.func private @_sparse_hybrid_qsort_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>, %arg5: i64) {
|
||||
// CHECK-LABEL: func.func @sparse_sort_3d_hybrid
|
||||
func.func @sparse_sort_3d_hybrid(%arg0: index, %arg1: memref<10xindex>, %arg2: memref<?xindex>, %arg3: memref<10xindex>) -> (memref<10xindex>, memref<?xindex>, memref<10xindex>) {
|
||||
sparse_tensor.sort hybrid_quick_sort %arg0, %arg1, %arg2, %arg3 : memref<10xindex>, memref<?xindex>, memref<10xindex>
|
||||
return %arg1, %arg2, %arg3 : memref<10xindex>, memref<?xindex>, memref<10xindex>
|
||||
}
|
||||
|
||||
// -----
|
||||
#ID_MAP=affine_map<(d0, d1) -> (d0, d1)>
|
||||
|
||||
// Only check the generated supporting functions. We have integration test to
|
||||
// verify correctness of the generated code.
|
||||
//
|
||||
// CHECK-DAG: func.func private @_sparse_binary_search_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) -> index {
|
||||
// CHECK-DAG: func.func private @_sparse_sort_stable_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) {
|
||||
// CHECK-LABEL: func.func @sparse_sort_3d_stable
|
||||
func.func @sparse_sort_3d_stable(%arg0: index, %arg1: memref<10xindex>, %arg2: memref<?xindex>, %arg3: memref<10xindex>) -> (memref<10xindex>, memref<?xindex>, memref<10xindex>) {
|
||||
sparse_tensor.sort insertion_sort_stable %arg0, %arg1, %arg2, %arg3 : memref<10xindex>, memref<?xindex>, memref<10xindex>
|
||||
return %arg1, %arg2, %arg3 : memref<10xindex>, memref<?xindex>, memref<10xindex>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Only check the generated supporting functions. We have integration test to
|
||||
// verify correctness of the generated code.
|
||||
//
|
||||
// CHECK-DAG: func.func private @_sparse_shift_down_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>, %arg5: index) {
|
||||
// CHECK-DAG: func.func private @_sparse_heap_sort_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) {
|
||||
// CHECK-LABEL: func.func @sparse_sort_3d_heap
|
||||
func.func @sparse_sort_3d_heap(%arg0: index, %arg1: memref<10xindex>, %arg2: memref<?xindex>, %arg3: memref<10xindex>) -> (memref<10xindex>, memref<?xindex>, memref<10xindex>) {
|
||||
sparse_tensor.sort heap_sort %arg0, %arg1, %arg2, %arg3 : memref<10xindex>, memref<?xindex>, memref<10xindex>
|
||||
return %arg1, %arg2, %arg3 : memref<10xindex>, memref<?xindex>, memref<10xindex>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Only check the generated supporting functions. We have integration test to
|
||||
// verify correctness of the generated code.
|
||||
//
|
||||
// CHECK-DAG: func.func private @_sparse_partition_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) -> index {
|
||||
// CHECK-DAG: func.func private @_sparse_qsort_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) {
|
||||
// CHECK-DAG: func.func private @_sparse_partition_0_1_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) -> index {
|
||||
// CHECK-DAG: func.func private @_sparse_qsort_0_1_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) {
|
||||
// CHECK-LABEL: func.func @sparse_sort_coo_quick
|
||||
func.func @sparse_sort_coo_quick(%arg0: index, %arg1: memref<100xindex>, %arg2: memref<?xf32>, %arg3: memref<10xi32>) -> (memref<100xindex>, memref<?xf32>, memref<10xi32>) {
|
||||
sparse_tensor.sort_coo quick_sort %arg0, %arg1 jointly %arg2, %arg3 {nx = 2 : index, ny = 1: index} : memref<100xindex> jointly memref<?xf32>, memref<10xi32>
|
||||
sparse_tensor.sort_coo quick_sort %arg0, %arg1 jointly %arg2, %arg3 {perm_map = #ID_MAP, ny = 1: index} : memref<100xindex> jointly memref<?xf32>, memref<10xi32>
|
||||
return %arg1, %arg2, %arg3 : memref<100xindex>, memref<?xf32>, memref<10xi32>
|
||||
}
|
||||
|
||||
// -----

#ID_MAP=affine_map<(d0, d1) -> (d0, d1)>

// Only check the generated supporting functions. We have integration tests to
// verify the correctness of the generated code.
//
// CHECK-DAG: func.func private @_sparse_binary_search_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) -> index {
// CHECK-DAG: func.func private @_sparse_sort_stable_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) {
// CHECK-DAG: func.func private @_sparse_shift_down_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>, %arg5: index) {
// CHECK-DAG: func.func private @_sparse_heap_sort_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) {
// CHECK-DAG: func.func private @_sparse_partition_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) -> index {
// CHECK-DAG: func.func private @_sparse_hybrid_qsort_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>, %arg5: i64) {
// CHECK-DAG: func.func private @_sparse_binary_search_0_1_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) -> index {
// CHECK-DAG: func.func private @_sparse_sort_stable_0_1_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) {
// CHECK-DAG: func.func private @_sparse_shift_down_0_1_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>, %arg5: index) {
// CHECK-DAG: func.func private @_sparse_heap_sort_0_1_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) {
// CHECK-DAG: func.func private @_sparse_partition_0_1_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) -> index {
// CHECK-DAG: func.func private @_sparse_hybrid_qsort_0_1_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>, %arg5: i64) {
// CHECK-LABEL: func.func @sparse_sort_coo_hybrid
func.func @sparse_sort_coo_hybrid(%arg0: index, %arg1: memref<100xindex>, %arg2: memref<?xf32>, %arg3: memref<10xi32>) -> (memref<100xindex>, memref<?xf32>, memref<10xi32>) {
sparse_tensor.sort_coo hybrid_quick_sort %arg0, %arg1 jointly %arg2, %arg3 {nx = 2 : index, ny = 1: index} : memref<100xindex> jointly memref<?xf32>, memref<10xi32>
sparse_tensor.sort_coo hybrid_quick_sort %arg0, %arg1 jointly %arg2, %arg3 {perm_map = #ID_MAP, ny = 1: index} : memref<100xindex> jointly memref<?xf32>, memref<10xi32>
return %arg1, %arg2, %arg3 : memref<100xindex>, memref<?xf32>, memref<10xi32>
}

// -----

#ID_MAP=affine_map<(d0, d1) -> (d0, d1)>

// Only check the generated supporting functions. We have integration tests to
// verify the correctness of the generated code.
//
// CHECK-DAG: func.func private @_sparse_binary_search_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) -> index {
// CHECK-DAG: func.func private @_sparse_sort_stable_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) {
// CHECK-DAG: func.func private @_sparse_binary_search_0_1_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) -> index {
// CHECK-DAG: func.func private @_sparse_sort_stable_0_1_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) {
// CHECK-LABEL: func.func @sparse_sort_coo_stable
func.func @sparse_sort_coo_stable(%arg0: index, %arg1: memref<100xindex>, %arg2: memref<?xf32>, %arg3: memref<10xi32>) -> (memref<100xindex>, memref<?xf32>, memref<10xi32>) {
sparse_tensor.sort_coo insertion_sort_stable %arg0, %arg1 jointly %arg2, %arg3 {nx = 2 : index, ny = 1: index} : memref<100xindex> jointly memref<?xf32>, memref<10xi32>
sparse_tensor.sort_coo insertion_sort_stable %arg0, %arg1 jointly %arg2, %arg3 {perm_map = #ID_MAP, ny = 1: index} : memref<100xindex> jointly memref<?xf32>, memref<10xi32>
return %arg1, %arg2, %arg3 : memref<100xindex>, memref<?xf32>, memref<10xi32>
}

// -----

#ID_MAP=affine_map<(d0, d1) -> (d0, d1)>

// Only check the generated supporting functions. We have integration tests to
// verify the correctness of the generated code.
//
// CHECK-DAG: func.func private @_sparse_shift_down_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>, %arg5: index) {
// CHECK-DAG: func.func private @_sparse_heap_sort_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) {
// CHECK-DAG: func.func private @_sparse_shift_down_0_1_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>, %arg5: index) {
// CHECK-DAG: func.func private @_sparse_heap_sort_0_1_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) {
// CHECK-LABEL: func.func @sparse_sort_coo_heap
func.func @sparse_sort_coo_heap(%arg0: index, %arg1: memref<100xindex>, %arg2: memref<?xf32>, %arg3: memref<10xi32>) -> (memref<100xindex>, memref<?xf32>, memref<10xi32>) {
sparse_tensor.sort_coo heap_sort %arg0, %arg1 jointly %arg2, %arg3 {nx = 2 : index, ny = 1: index} : memref<100xindex> jointly memref<?xf32>, memref<10xi32>
sparse_tensor.sort_coo heap_sort %arg0, %arg1 jointly %arg2, %arg3 {perm_map = #ID_MAP, ny = 1: index} : memref<100xindex> jointly memref<?xf32>, memref<10xi32>
return %arg1, %arg2, %arg3 : memref<100xindex>, memref<?xf32>, memref<10xi32>
}

@@ -436,7 +436,7 @@ func.func @sparse_expansion3(%arg0: index, %arg1: index) -> memref<?xindex> {
// CHECK-DAG: %[[A9:.*]] = arith.constant 0.000000e+00 : f64
// CHECK-DAG: %[[A10:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[A11:.*]] = arith.constant 0 : index
// CHECK: sparse_tensor.sort hybrid_quick_sort %[[A7]], %[[A6]] : memref<?xindex>
// CHECK: sparse_tensor.sort_coo hybrid_quick_sort %[[A7]], %[[A6]]
// CHECK: %[[A12:.*]]:4 = scf.for %[[A13:.*]] = %[[A11]] to %[[A7]] step %[[A10]] iter_args(%[[A14:.*]] = %[[A0]], %[[A15:.*]] = %[[A1]], %[[A16:.*]] = %[[A2]], %[[A17:.*]] = %[[A3]])
// CHECK: %[[A18:.*]] = memref.load %[[A6]]{{\[}}%[[A13]]] : memref<?xindex>
// CHECK: %[[A19:.*]] = memref.load %[[A4]]{{\[}}%[[A18]]] : memref<?xf64>
@@ -484,7 +484,7 @@ func.func @sparse_compression_1d(%tensor: tensor<100xf64, #SV>,
// CHECK: %[[A11:.*]] = arith.constant 0.000000e+00 : f64
// CHECK: %[[A12:.*]] = arith.constant 1 : index
// CHECK: %[[A13:.*]] = arith.constant 0 : index
// CHECK: sparse_tensor.sort hybrid_quick_sort %[[A7]], %[[A6]] : memref<?xindex>
// CHECK: sparse_tensor.sort_coo hybrid_quick_sort %[[A7]], %[[A6]]
// CHECK: %[[A14:.*]]:4 = scf.for %[[A15:.*]] = %[[A13]] to %[[A7]] step %[[A12]] iter_args(%[[A16:.*]] = %[[A0]], %[[A17:.*]] = %[[A1]], %[[A18:.*]] = %[[A2]], %[[A19:.*]] = %[[A3]]) -> (memref<?xi32>, memref<?xi64>, memref<?xf64>, !sparse_tensor.storage_specifier
// CHECK: %[[A20:.*]] = memref.load %[[A6]]{{\[}}%[[A15]]] : memref<?xindex>
// CHECK: %[[A21:.*]] = memref.load %[[A4]]{{\[}}%[[A20]]] : memref<?xf64>
@@ -712,7 +712,7 @@ func.func @sparse_convert_element_type(%arg0: tensor<32xf32, #SparseVector>) ->
// CHECK: %[[A33:.*]] = call @getSparseTensorReaderReadToBuffers0F32(%[[A5]], %[[A32]], %[[A14]], %[[A15]])
// CHECK: %[[A34:.*]] = arith.cmpi eq, %[[A33]], %[[A1]] : i1
// CHECK: scf.if %[[A34]] {
// CHECK: sparse_tensor.sort_coo hybrid_quick_sort %[[A10]], %[[A14]] jointly %[[A15]] {nx = 2 : index, ny = 0 : index} : memref<?xindex> jointly memref<?xf32>
// CHECK: sparse_tensor.sort_coo hybrid_quick_sort %[[A10]], %[[A14]] jointly %[[A15]] {ny = 0 : index, perm_map = #{{.*}}} : memref<?xindex> jointly memref<?xf32>
// CHECK: }
// CHECK: memref.store %[[A10]], %[[A27]]{{\[}}%[[A2]]] : memref<?xindex>
// CHECK: %[[A36:.*]] = sparse_tensor.storage_specifier.set %[[A30]] crd_mem_sz at 0 with %[[A11]]

@@ -178,7 +178,7 @@ func.func @sparse_convert_singleton(%arg0: tensor<?xf32, #SparseSingleton64>) ->
// CHECK-RWT: %[[VAL_16:.*]] = sparse_tensor.load %[[VAL_17:.*]] hasInserts : tensor<?x?x?xf32, #{{.*}}>>
// CHECK-RWT: %[[VAL_18:.*]] = sparse_tensor.values %[[VAL_16]] : tensor<?x?x?xf32, #{{.*}}>> to memref<?xf32>
// CHECK-RWT: %[[VAL_19:.*]] = sparse_tensor.coordinates_buffer %[[VAL_16]] : tensor<?x?x?xf32, #{{.*}}>> to memref<?xindex>
// CHECK-RWT: sparse_tensor.sort_coo hybrid_quick_sort %[[VAL_7]], %[[VAL_19]] jointly %[[VAL_18]] {nx = 3 : index, ny = 0 : index}
// CHECK-RWT: sparse_tensor.sort_coo hybrid_quick_sort %[[VAL_7]], %[[VAL_19]] jointly %[[VAL_18]] {ny = 0 : index, perm_map = #map}
// CHECK-RWT: %[[VAL_20:.*]] = bufferization.alloc_tensor(%[[VAL_4]], %[[VAL_5]], %[[VAL_6]]) size_hint=%[[VAL_7]]
// CHECK-RWT: %[[VAL_21:.*]] = sparse_tensor.foreach in %[[VAL_16]] init(%[[VAL_20]])
// CHECK-RWT: ^bb0(%[[VAL_22:.*]]: index, %[[VAL_23:.*]]: index, %[[VAL_24:.*]]: index, %[[VAL_25:.*]]: f32, %[[VAL_26:.*]]: tensor<?x?x?xf32, #{{.*}}>>):

@@ -790,60 +790,51 @@ func.func @sparse_tensor_foreach(%arg0: tensor<2x4xf64, #DCSR>, %arg1: f32) -> (
return
}

// -----

// TODO: a test case with empty xs doesn't work due to some parser issues.

func.func @sparse_sort_x_type( %arg0: index, %arg1: memref<?xf32>) {
// expected-error@+1 {{operand #1 must be 1D memref of integer or index values}}
sparse_tensor.sort hybrid_quick_sort %arg0, %arg1: memref<?xf32>
}

// -----

func.func @sparse_sort_dim_too_small(%arg0: memref<10xindex>) {
%i20 = arith.constant 20 : index
// expected-error@+1 {{xs and ys need to have a dimension >= n: 10 < 20}}
sparse_tensor.sort insertion_sort_stable %i20, %arg0 : memref<10xindex>
return
}

// -----

func.func @sparse_sort_mismatch_x_type(%arg0: index, %arg1: memref<10xindex>, %arg2: memref<10xi8>) {
// expected-error@+1 {{mismatch xs element types}}
sparse_tensor.sort hybrid_quick_sort %arg0, %arg1, %arg2 : memref<10xindex>, memref<10xi8>
return
}

// -----
#MAP = affine_map<(i,j) -> (i,j)>

func.func @sparse_sort_coo_x_type( %arg0: index, %arg1: memref<?xf32>) {
// expected-error@+1 {{operand #1 must be 1D memref of integer or index values}}
sparse_tensor.sort_coo insertion_sort_stable %arg0, %arg1: memref<?xf32>
sparse_tensor.sort_coo insertion_sort_stable %arg0, %arg1 {perm_map = #MAP} : memref<?xf32>
return
}

// -----

#MAP = affine_map<(i,j) -> (i,j)>

func.func @sparse_sort_coo_x_too_small(%arg0: memref<50xindex>) {
%i20 = arith.constant 20 : index
// expected-error@+1 {{Expected dimension(xy) >= n * (nx + ny) got 50 < 60}}
sparse_tensor.sort_coo hybrid_quick_sort %i20, %arg0 {nx = 2 : index, ny = 1 : index} : memref<50xindex>
// expected-error@+1 {{Expected dimension(xy) >= n * (rank(perm_map) + ny) got 50 < 60}}
sparse_tensor.sort_coo hybrid_quick_sort %i20, %arg0 {perm_map = #MAP, ny = 1 : index} : memref<50xindex>
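// The required size here is n * (rank(perm_map) + ny) = 20 * (2 + 1) = 60, so the
// 50-element buffer is too small, matching the expected diagnostics above.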
return
}

// -----

#MAP = affine_map<(i,j) -> (i,j)>

func.func @sparse_sort_coo_y_too_small(%arg0: memref<60xindex>, %arg1: memref<10xf32>) {
%i20 = arith.constant 20 : index
// expected-error@+1 {{Expected dimension(y) >= n got 10 < 20}}
sparse_tensor.sort_coo insertion_sort_stable %i20, %arg0 jointly %arg1 {nx = 2 : index, ny = 1 : index} : memref<60xindex> jointly memref<10xf32>
sparse_tensor.sort_coo insertion_sort_stable %i20, %arg0 jointly %arg1 {perm_map = #MAP, ny = 1 : index} : memref<60xindex> jointly memref<10xf32>
return
}

// -----

#NON_PERM_MAP = affine_map<(i,j) -> (i,i)>
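// #NON_PERM_MAP maps both results to i and drops j, so it is not a permutation
// and is rejected by the verifier below.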

func.func @sparse_sort_coo_no_perm(%arg0: index, %arg1: memref<?xindex>) -> (memref<?xindex>) {
// expected-error@+1 {{Expected a permutation map, got (d0, d1) -> (d0, d0)}}
sparse_tensor.sort_coo hybrid_quick_sort %arg0, %arg1 {perm_map = #NON_PERM_MAP, ny = 1 : index}: memref<?xindex>
return %arg1 : memref<?xindex>
}

// -----

#CSR = #sparse_tensor.encoding<{map = (d0, d1) -> (d0 : dense, d1 : compressed)}>

func.func @sparse_alloc_escapes(%arg0: index) -> tensor<10x?xf64, #CSR> {

@@ -612,79 +612,29 @@ func.func @sparse_tensor_foreach(%arg0: tensor<2x4xf64, #DCSR>, %arg1: f32) -> (

// -----

// CHECK-LABEL: func @sparse_sort_1d0v(
// CHECK-SAME: %[[A:.*]]: index,
// CHECK-SAME: %[[B:.*]]: memref<?xindex>)
// CHECK: sparse_tensor.sort hybrid_quick_sort %[[A]], %[[B]] : memref<?xindex>
// CHECK: return %[[B]]
func.func @sparse_sort_1d0v(%arg0: index, %arg1: memref<?xindex>) -> (memref<?xindex>) {
sparse_tensor.sort hybrid_quick_sort %arg0, %arg1 : memref<?xindex>
return %arg1 : memref<?xindex>
}

// -----

// CHECK-LABEL: func @sparse_sort_1d2v(
// CHECK-SAME: %[[A:.*]]: index,
// CHECK-SAME: %[[B:.*]]: memref<20xindex>,
// CHECK-SAME: %[[C:.*]]: memref<10xindex>,
// CHECK-SAME: %[[D:.*]]: memref<?xf32>)
// CHECK: sparse_tensor.sort hybrid_quick_sort %[[A]], %[[B]] jointly %[[C]], %[[D]] : memref<20xindex> jointly memref<10xindex>, memref<?xf32>
// CHECK: return %[[B]], %[[C]], %[[D]]
func.func @sparse_sort_1d2v(%arg0: index, %arg1: memref<20xindex>, %arg2: memref<10xindex>, %arg3: memref<?xf32>) -> (memref<20xindex>, memref<10xindex>, memref<?xf32>) {
sparse_tensor.sort hybrid_quick_sort %arg0, %arg1 jointly %arg2, %arg3 : memref<20xindex> jointly memref<10xindex>, memref<?xf32>
return %arg1, %arg2, %arg3 : memref<20xindex>, memref<10xindex>, memref<?xf32>
}

// -----

// CHECK-LABEL: func @sparse_sort_2d1v(
// CHECK-SAME: %[[A:.*]]: index,
// CHECK-SAME: %[[B:.*]]: memref<10xi8>,
// CHECK-SAME: %[[C:.*]]: memref<20xi8>,
// CHECK-SAME: %[[D:.*]]: memref<10xf64>)
// CHECK: sparse_tensor.sort hybrid_quick_sort %[[A]], %[[B]], %[[C]] jointly %[[D]] : memref<10xi8>, memref<20xi8> jointly memref<10xf64>
// CHECK: return %[[B]], %[[C]], %[[D]]
func.func @sparse_sort_2d1v(%arg0: index, %arg1: memref<10xi8>, %arg2: memref<20xi8>, %arg3: memref<10xf64>) -> (memref<10xi8>, memref<20xi8>, memref<10xf64>) {
sparse_tensor.sort hybrid_quick_sort %arg0, %arg1, %arg2 jointly %arg3 : memref<10xi8>, memref<20xi8> jointly memref<10xf64>
return %arg1, %arg2, %arg3 : memref<10xi8>, memref<20xi8>, memref<10xf64>
}

// -----

// CHECK-LABEL: func @sparse_sort_stable(
// CHECK-SAME: %[[A:.*]]: index,
// CHECK-SAME: %[[B:.*]]: memref<10xi8>,
// CHECK-SAME: %[[C:.*]]: memref<20xi8>,
// CHECK-SAME: %[[D:.*]]: memref<10xf64>)
// CHECK: sparse_tensor.sort insertion_sort_stable %[[A]], %[[B]], %[[C]] jointly %[[D]] : memref<10xi8>, memref<20xi8> jointly memref<10xf64>
// CHECK: return %[[B]], %[[C]], %[[D]]
func.func @sparse_sort_stable(%arg0: index, %arg1: memref<10xi8>, %arg2: memref<20xi8>, %arg3: memref<10xf64>) -> (memref<10xi8>, memref<20xi8>, memref<10xf64>) {
sparse_tensor.sort insertion_sort_stable %arg0, %arg1, %arg2 jointly %arg3 : memref<10xi8>, memref<20xi8> jointly memref<10xf64>
return %arg1, %arg2, %arg3 : memref<10xi8>, memref<20xi8>, memref<10xf64>
}

// -----
#ID_MAP = affine_map<(i,j) -> (i,j)>

// CHECK-LABEL: func @sparse_sort_coo(
// CHECK-SAME: %[[A:.*]]: index,
// CHECK-SAME: %[[B:.*]]: memref<?xindex>)
// CHECK: sparse_tensor.sort_coo hybrid_quick_sort %[[A]], %[[B]] {nx = 2 : index, ny = 1 : index} : memref<?xindex>
// CHECK: sparse_tensor.sort_coo hybrid_quick_sort %[[A]], %[[B]] {ny = 1 : index, perm_map = #{{.*}}} : memref<?xindex>
// CHECK: return %[[B]]
func.func @sparse_sort_coo(%arg0: index, %arg1: memref<?xindex>) -> (memref<?xindex>) {
sparse_tensor.sort_coo hybrid_quick_sort %arg0, %arg1 {nx = 2 : index, ny = 1 : index}: memref<?xindex>
sparse_tensor.sort_coo hybrid_quick_sort %arg0, %arg1 {perm_map = #ID_MAP, ny = 1 : index}: memref<?xindex>
return %arg1 : memref<?xindex>
}

// -----

#ID_MAP = affine_map<(i,j) -> (i,j)>

// CHECK-LABEL: func @sparse_sort_coo_stable(
// CHECK-SAME: %[[A:.*]]: index,
// CHECK-SAME: %[[B:.*]]: memref<?xi64>,
// CHECK-SAME: %[[C:.*]]: memref<?xf32>)
// CHECK: sparse_tensor.sort_coo insertion_sort_stable %[[A]], %[[B]] jointly %[[C]] {nx = 2 : index, ny = 1 : index}
// CHECK: sparse_tensor.sort_coo insertion_sort_stable %[[A]], %[[B]] jointly %[[C]] {ny = 1 : index, perm_map = #{{.*}}}
// CHECK: return %[[B]], %[[C]]
func.func @sparse_sort_coo_stable(%arg0: index, %arg1: memref<?xi64>, %arg2: memref<?xf32>) -> (memref<?xi64>, memref<?xf32>) {
sparse_tensor.sort_coo insertion_sort_stable %arg0, %arg1 jointly %arg2 {nx = 2 : index, ny = 1 : index}: memref<?xi64> jointly memref<?xf32>
sparse_tensor.sort_coo insertion_sort_stable %arg0, %arg1 jointly %arg2 {perm_map = #ID_MAP, ny = 1 : index}: memref<?xi64> jointly memref<?xf32>
return %arg1, %arg2 : memref<?xi64>, memref<?xf32>
}

@@ -116,7 +116,7 @@
// CHECK: } {"Emitted from" = "linalg.generic"}
// CHECK: scf.yield %[[VAL_64:.*]] : index
// CHECK: } {"Emitted from" = "linalg.generic"}
// CHECK: sparse_tensor.sort hybrid_quick_sort %[[VAL_65:.*]], %[[VAL_33]] : memref<?xindex>
// CHECK: sparse_tensor.sort_coo hybrid_quick_sort %[[VAL_65:.*]], %[[VAL_33]]
// CHECK: %[[VAL_66:.*]]:4 = scf.for %[[VAL_67:.*]] = %[[VAL_10]] to %[[VAL_65]] step %[[VAL_11]] iter_args(%[[VAL_68:.*]] = %[[VAL_36]], %[[VAL_69:.*]] = %[[VAL_37]], %[[VAL_70:.*]] = %[[VAL_38]], %[[VAL_71:.*]] = %[[VAL_39]]) -> (memref<?xindex>, memref<?xindex>, memref<?xf64>, !sparse_tensor.storage_specifier
// CHECK: %[[VAL_72:.*]] = memref.load %[[VAL_32]]{{\[}}%[[VAL_67]]] : memref<4xindex>
// CHECK: %[[VAL_73:.*]] = memref.load %[[VAL_30]]{{\[}}%[[VAL_72]]] : memref<4xf64>

@@ -1,187 +0,0 @@
//--------------------------------------------------------------------------------------------------
// WHEN CREATING A NEW TEST, PLEASE JUST COPY & PASTE WITHOUT EDITS.
//
// Set-up that's shared across all tests in this directory. In principle, this
// config could be moved to lit.local.cfg. However, there are downstream users that
// do not use these LIT config files. Hence why this is kept inline.
//
// DEFINE: %{sparse_compiler_opts} = enable-runtime-library=true
// DEFINE: %{sparse_compiler_opts_sve} = enable-arm-sve=true %{sparse_compiler_opts}
// DEFINE: %{compile} = mlir-opt %s --sparse-compiler="%{sparse_compiler_opts}"
// DEFINE: %{compile_sve} = mlir-opt %s --sparse-compiler="%{sparse_compiler_opts_sve}"
// DEFINE: %{run_libs} = -shared-libs=%mlir_c_runner_utils,%mlir_runner_utils
// DEFINE: %{run_opts} = -e entry -entry-point-result=void
// DEFINE: %{run} = mlir-cpu-runner %{run_opts} %{run_libs}
// DEFINE: %{run_sve} = %mcr_aarch64_cmd --march=aarch64 --mattr="+sve" %{run_opts} %{run_libs}
//
// DEFINE: %{env} =
//--------------------------------------------------------------------------------------------------

// REDEFINE: %{sparse_compiler_opts} = enable-runtime-library=false
// RUN: %{compile} | %{run} | FileCheck %s
//
// Do the same run, but now with vectorization.
// REDEFINE: %{sparse_compiler_opts} = enable-runtime-library=false vl=2 reassociate-fp-reductions=true enable-index-optimizations=true
// RUN: %{compile} | %{run} | FileCheck %s
//
// Do the same run, but now with VLA vectorization.
// RUN: %if mlir_arm_sve_tests %{ %{compile_sve} | %{run_sve} | FileCheck %s %}

module {
func.func private @printMemref1dI32(%ptr : memref<?xi32>) attributes { llvm.emit_c_interface }

// Stores 5 values to the memref buffer.
func.func @storeValuesTo(%b: memref<?xi32>, %v0: i32, %v1: i32, %v2: i32,
%v3: i32, %v4: i32) -> () {
%i0 = arith.constant 0 : index
%i1 = arith.constant 1 : index
%i2 = arith.constant 2 : index
%i3 = arith.constant 3 : index
%i4 = arith.constant 4 : index
memref.store %v0, %b[%i0] : memref<?xi32>
memref.store %v1, %b[%i1] : memref<?xi32>
memref.store %v2, %b[%i2] : memref<?xi32>
memref.store %v3, %b[%i3] : memref<?xi32>
memref.store %v4, %b[%i4] : memref<?xi32>
return
}

// The main driver.
func.func @entry() {
%c0 = arith.constant 0 : i32
%c1 = arith.constant 1 : i32
%c2 = arith.constant 2 : i32
%c3 = arith.constant 3 : i32
%c4 = arith.constant 4 : i32
%c5 = arith.constant 5 : i32
%c6 = arith.constant 6 : i32
%c7 = arith.constant 7 : i32
%c8 = arith.constant 8 : i32
%c9 = arith.constant 9 : i32
%c10 = arith.constant 10 : i32
%c100 = arith.constant 100 : i32

%i0 = arith.constant 0 : index
%i4 = arith.constant 4 : index
%i5 = arith.constant 5 : index

// Prepare a buffer.
%x0s = memref.alloc() : memref<5xi32>
%x0 = memref.cast %x0s : memref<5xi32> to memref<?xi32>
call @storeValuesTo(%x0, %c10, %c2, %c0, %c5, %c1)
: (memref<?xi32>, i32, i32, i32, i32, i32) -> ()

// Sort 0 elements.
// Quick sort.
// CHECK: [10, 2, 0, 5, 1]
sparse_tensor.sort quick_sort %i0, %x0 : memref<?xi32>
call @printMemref1dI32(%x0) : (memref<?xi32>) -> ()
// Stable sort.
// CHECK: [10, 2, 0, 5, 1]
sparse_tensor.sort insertion_sort_stable %i0, %x0 : memref<?xi32>
call @printMemref1dI32(%x0) : (memref<?xi32>) -> ()
// Heap sort.
// CHECK: [10, 2, 0, 5, 1]
sparse_tensor.sort heap_sort %i0, %x0 : memref<?xi32>
call @printMemref1dI32(%x0) : (memref<?xi32>) -> ()
// Hybrid sort.
// CHECK: [10, 2, 0, 5, 1]
sparse_tensor.sort hybrid_quick_sort %i0, %x0 : memref<?xi32>
call @printMemref1dI32(%x0) : (memref<?xi32>) -> ()

// Sort the first 4 elements, with the last valid value untouched.
// Quick sort.
// CHECK: [0, 2, 5, 10, 1]
sparse_tensor.sort quick_sort %i4, %x0 : memref<?xi32>
call @printMemref1dI32(%x0) : (memref<?xi32>) -> ()
// Stable sort.
// CHECK: [0, 2, 5, 10, 1]
call @storeValuesTo(%x0, %c10, %c2, %c0, %c5, %c1)
: (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
sparse_tensor.sort insertion_sort_stable %i4, %x0 : memref<?xi32>
call @printMemref1dI32(%x0) : (memref<?xi32>) -> ()
// Heap sort.
// CHECK: [0, 2, 5, 10, 1]
call @storeValuesTo(%x0, %c10, %c2, %c0, %c5, %c1)
: (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
sparse_tensor.sort heap_sort %i4, %x0 : memref<?xi32>
call @printMemref1dI32(%x0) : (memref<?xi32>) -> ()
// Hybrid sort.
// CHECK: [0, 2, 5, 10, 1]
sparse_tensor.sort hybrid_quick_sort %i4, %x0 : memref<?xi32>
call @printMemref1dI32(%x0) : (memref<?xi32>) -> ()

// Prepare more buffers of different dimensions.
%x1s = memref.alloc() : memref<10xi32>
%x1 = memref.cast %x1s : memref<10xi32> to memref<?xi32>
%x2s = memref.alloc() : memref<6xi32>
%x2 = memref.cast %x2s : memref<6xi32> to memref<?xi32>
%y0s = memref.alloc() : memref<7xi32>
%y0 = memref.cast %y0s : memref<7xi32> to memref<?xi32>

// Sort "parallel arrays".
// CHECK: [1, 1, 2, 5, 10]
// CHECK: [3, 3, 1, 10, 1
// CHECK: [9, 9, 4, 7, 2
// CHECK: [7, 8, 10, 9, 6
call @storeValuesTo(%x0, %c10, %c2, %c1, %c5, %c1)
: (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
call @storeValuesTo(%x1, %c1, %c1, %c3, %c10, %c3)
: (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
call @storeValuesTo(%x2, %c2, %c4, %c9, %c7, %c9)
: (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
call @storeValuesTo(%y0, %c6, %c10, %c8, %c9, %c7)
: (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
sparse_tensor.sort quick_sort %i5, %x0, %x1, %x2 jointly %y0
: memref<?xi32>, memref<?xi32>, memref<?xi32> jointly memref<?xi32>
call @printMemref1dI32(%x0) : (memref<?xi32>) -> ()
call @printMemref1dI32(%x1) : (memref<?xi32>) -> ()
call @printMemref1dI32(%x2) : (memref<?xi32>) -> ()
call @printMemref1dI32(%y0) : (memref<?xi32>) -> ()
// Stable sort.
// CHECK: [1, 1, 2, 5, 10]
// CHECK: [3, 3, 1, 10, 1
// CHECK: [9, 9, 4, 7, 2
// CHECK: [8, 7, 10, 9, 6
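// Positions 2 and 4 carry the identical key (1, 3, 9), so a stable sort keeps their
// original relative order; that is why %y0 prints 8 before 7 here, unlike quick sort above.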
call @storeValuesTo(%x0, %c10, %c2, %c1, %c5, %c1)
: (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
call @storeValuesTo(%x1, %c1, %c1, %c3, %c10, %c3)
: (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
call @storeValuesTo(%x2, %c2, %c4, %c9, %c7, %c9)
: (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
call @storeValuesTo(%y0, %c6, %c10, %c8, %c9, %c7)
: (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
sparse_tensor.sort insertion_sort_stable %i5, %x0, %x1, %x2 jointly %y0
: memref<?xi32>, memref<?xi32>, memref<?xi32> jointly memref<?xi32>
call @printMemref1dI32(%x0) : (memref<?xi32>) -> ()
call @printMemref1dI32(%x1) : (memref<?xi32>) -> ()
call @printMemref1dI32(%x2) : (memref<?xi32>) -> ()
call @printMemref1dI32(%y0) : (memref<?xi32>) -> ()
// Heap sort.
// CHECK: [1, 1, 2, 5, 10]
// CHECK: [3, 3, 1, 10, 1
// CHECK: [9, 9, 4, 7, 2
// CHECK: [7, 8, 10, 9, 6
call @storeValuesTo(%x0, %c10, %c2, %c1, %c5, %c1)
: (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
call @storeValuesTo(%x1, %c1, %c1, %c3, %c10, %c3)
: (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
call @storeValuesTo(%x2, %c2, %c4, %c9, %c7, %c9)
: (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
call @storeValuesTo(%y0, %c6, %c10, %c8, %c9, %c7)
: (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
sparse_tensor.sort heap_sort %i5, %x0, %x1, %x2 jointly %y0
: memref<?xi32>, memref<?xi32>, memref<?xi32> jointly memref<?xi32>
call @printMemref1dI32(%x0) : (memref<?xi32>) -> ()
call @printMemref1dI32(%x1) : (memref<?xi32>) -> ()
call @printMemref1dI32(%x2) : (memref<?xi32>) -> ()
call @printMemref1dI32(%y0) : (memref<?xi32>) -> ()

// Release the buffers.
memref.dealloc %x0 : memref<?xi32>
memref.dealloc %x1 : memref<?xi32>
memref.dealloc %x2 : memref<?xi32>
memref.dealloc %y0 : memref<?xi32>
return
}
}
@@ -28,6 +28,8 @@
// Do the same run, but now with VLA vectorization.
// RUN: %if mlir_arm_sve_tests %{ %{compile_sve} | %{run_sve} | FileCheck %s %}

#ID_MAP = affine_map<(d0, d1, d2) -> (d1, d2, d0)>

module {
// Stores 5 values to the memref buffer.
func.func @storeValuesTo(%b: memref<?xi32>, %v0: i32, %v1: i32, %v2: i32,
@@ -94,11 +96,11 @@ module {
%y1 = memref.cast %y1s : memref<7xi32> to memref<?xi32>

// Sort "parallel arrays".
// CHECK: ( 1, 1, 3, 3, 10 )
// CHECK: ( 2, 10, 1, 1, 5 )
// CHECK: ( 4, 2, 9, 9, 7 )
// CHECK: ( 10, 6, 7, 8, 9 )
// CHECK: ( 7, 5, 7, 4, 9 )
// CHECK: ( 1, 1, 2, 5, 10 )
// CHECK: ( 9, 9, 4, 7, 2 )
// CHECK: ( 3, 3, 1, 10, 1 )
// CHECK: ( 7, 8, 10, 9, 6 )
// CHECK: ( 7, 4, 7, 9, 5 )
call @storeValuesToStrided(%x0, %c1, %c1, %c3, %c10, %c3)
: (memref<?xi32, strided<[4], offset: ?>>, i32, i32, i32, i32, i32) -> ()
call @storeValuesToStrided(%x1, %c10, %c2, %c1, %c5, %c1)
@@ -109,24 +111,25 @@ module {
: (memref<?xi32, strided<[4], offset: ?>>, i32, i32, i32, i32, i32) -> ()
call @storeValuesTo(%y1, %c5, %c7, %c4, %c9, %c7)
: (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
sparse_tensor.sort_coo quick_sort %i5, %xy jointly %y1 {nx = 3 : index, ny = 1 : index}
sparse_tensor.sort_coo quick_sort %i5, %xy jointly %y1 {perm_map = #ID_MAP, ny = 1 : index}
: memref<?xi32> jointly memref<?xi32>
%x0v = vector.transfer_read %x0[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
vector.print %x0v : vector<5xi32>
// Dumps memory in the same order as the perm_map such that the output is ordered.
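// (With #ID_MAP = affine_map<(d0, d1, d2) -> (d1, d2, d0)> as defined above, the keys are
// compared as (x1, x2, x0), which is why %x1 and %x2 are read before %x0 below.)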
%x1v = vector.transfer_read %x1[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
vector.print %x1v : vector<5xi32>
%x2v = vector.transfer_read %x2[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
vector.print %x2v : vector<5xi32>
%x0v = vector.transfer_read %x0[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
vector.print %x0v : vector<5xi32>
%y0v = vector.transfer_read %y0[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
vector.print %y0v : vector<5xi32>
%y1v = vector.transfer_read %y1[%i0], %c100: memref<?xi32>, vector<5xi32>
vector.print %y1v : vector<5xi32>
// Stable sort.
// CHECK: ( 1, 1, 3, 3, 10 )
// CHECK: ( 2, 10, 1, 1, 5 )
// CHECK: ( 4, 2, 9, 9, 7 )
// CHECK: ( 10, 6, 8, 7, 9 )
// CHECK: ( 7, 5, 4, 7, 9 )
// CHECK: ( 1, 1, 2, 5, 10 )
// CHECK: ( 9, 9, 4, 7, 2 )
// CHECK: ( 3, 3, 1, 10, 1 )
// CHECK: ( 8, 7, 10, 9, 6 )
// CHECK: ( 4, 7, 7, 9, 5 )
call @storeValuesToStrided(%x0, %c1, %c1, %c3, %c10, %c3)
: (memref<?xi32, strided<[4], offset: ?>>, i32, i32, i32, i32, i32) -> ()
call @storeValuesToStrided(%x1, %c10, %c2, %c1, %c5, %c1)
@@ -137,24 +140,24 @@ module {
: (memref<?xi32, strided<[4], offset: ?>>, i32, i32, i32, i32, i32) -> ()
call @storeValuesTo(%y1, %c5, %c7, %c4, %c9, %c7)
: (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
sparse_tensor.sort_coo insertion_sort_stable %i5, %xy jointly %y1 {nx = 3 : index, ny = 1 : index}
sparse_tensor.sort_coo insertion_sort_stable %i5, %xy jointly %y1 {perm_map = #ID_MAP, ny = 1 : index}
: memref<?xi32> jointly memref<?xi32>
%x0v2 = vector.transfer_read %x0[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
vector.print %x0v2 : vector<5xi32>
%x1v2 = vector.transfer_read %x1[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
vector.print %x1v2 : vector<5xi32>
%x2v2 = vector.transfer_read %x2[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
vector.print %x2v2 : vector<5xi32>
%x0v2 = vector.transfer_read %x0[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
vector.print %x0v2 : vector<5xi32>
%y0v2 = vector.transfer_read %y0[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
vector.print %y0v2 : vector<5xi32>
%y1v2 = vector.transfer_read %y1[%i0], %c100: memref<?xi32>, vector<5xi32>
vector.print %y1v2 : vector<5xi32>
// Heap sort.
// CHECK: ( 1, 1, 3, 3, 10 )
// CHECK: ( 2, 10, 1, 1, 5 )
// CHECK: ( 4, 2, 9, 9, 7 )
// CHECK: ( 10, 6, 8, 7, 9 )
// CHECK: ( 7, 5, 4, 7, 9 )
// CHECK: ( 1, 1, 2, 5, 10 )
// CHECK: ( 9, 9, 4, 7, 2 )
// CHECK: ( 3, 3, 1, 10, 1 )
// CHECK: ( 7, 8, 10, 9, 6 )
// CHECK: ( 7, 4, 7, 9, 5 )
call @storeValuesToStrided(%x0, %c1, %c1, %c3, %c10, %c3)
: (memref<?xi32, strided<[4], offset: ?>>, i32, i32, i32, i32, i32) -> ()
call @storeValuesToStrided(%x1, %c10, %c2, %c1, %c5, %c1)
@@ -165,14 +168,14 @@ module {
: (memref<?xi32, strided<[4], offset: ?>>, i32, i32, i32, i32, i32) -> ()
call @storeValuesTo(%y1, %c5, %c7, %c4, %c9, %c7)
: (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
sparse_tensor.sort_coo heap_sort %i5, %xy jointly %y1 {nx = 3 : index, ny = 1 : index}
sparse_tensor.sort_coo heap_sort %i5, %xy jointly %y1 {perm_map = #ID_MAP, ny = 1 : index}
: memref<?xi32> jointly memref<?xi32>
%x0v3 = vector.transfer_read %x0[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
vector.print %x0v3 : vector<5xi32>
%x1v3 = vector.transfer_read %x1[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
vector.print %x1v3 : vector<5xi32>
%x2v3 = vector.transfer_read %x2[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
vector.print %x2v3 : vector<5xi32>
%x0v3 = vector.transfer_read %x0[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
vector.print %x0v3 : vector<5xi32>
%y0v3 = vector.transfer_read %y0[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
vector.print %y0v3 : vector<5xi32>
%y1v3 = vector.transfer_read %y1[%i0], %c100: memref<?xi32>, vector<5xi32>