[mlir][sparse] support iteration over compressed-hi dimension level in loop emitter
Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D148668
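For context: unlike a regular compressed level, whose positions buffer shares boundaries between consecutive parents (bounds pos[p] and pos[p+1]), a compressed-hi level stores an explicit lower/upper position pair per parent entry, which is why the loop emitter changes below scale the parent position by 2 before loading bounds. A minimal C++ sketch of that indexing; the helper name and types are illustrative, not the actual MLIR API:

#include <cstdint>
#include <utility>
#include <vector>

// Hypothetical helper (not the MLIR API): compute the half-open range of
// child entries owned by parent position `p` of a sparse level, given its
// positions buffer. `isCompressedHi` mirrors isCompressedWithHiDLT() below.
std::pair<uint64_t, uint64_t>
childRange(const std::vector<uint64_t> &positions, uint64_t p,
           bool isCompressedHi) {
  // Regular compressed levels share boundaries: [pos[p], pos[p+1]).
  // Compressed-hi levels store an explicit pair per parent entry:
  // [pos[2*p], pos[2*p+1]), so neighboring ranges need not be contiguous.
  const uint64_t pLo = isCompressedHi ? 2 * p : p;
  return {positions[pLo], positions[pLo + 1]};
}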
@@ -388,7 +388,7 @@ void LoopEmitter::initializeLoopEmit(OpBuilder &builder, Location loc,
                                  !highs[t][l]);
       const auto lvlTp = lvlTypes[t][l];
       // Handle sparse storage schemes.
-      if (isCompressedDLT(lvlTp)) {
+      if (isCompressedDLT(lvlTp) || isCompressedWithHiDLT(lvlTp)) {
         // Generate sparse primitives to obtain positions and coordinates.
         positionsBuffers[t][l] = genToPositions(builder, loc, tensor, l);
         coordinatesBuffers[t][l] =
@@ -557,6 +557,7 @@ Operation *LoopEmitter::emitForLoopOverTensorAtLvl(
     OpBuilder &builder, Location loc, TensorId tid, Level dstLvl, Value lo,
     Value hi, MutableArrayRef<Value> reduc, bool isParallel) {
   bool isSparseCond = isCompressedDLT(lvlTypes[tid][dstLvl]) ||
+                      isCompressedWithHiDLT(lvlTypes[tid][dstLvl]) ||
                       isSingletonDLT(lvlTypes[tid][dstLvl]);

   const auto reassoc = getCollapseReassociation(tid, dstLvl);
@@ -695,7 +696,7 @@ Operation *LoopEmitter::enterLoopOverTensorAtLvl(
     auto lvlType = lvlTypes[t][l];
     // Must be a recognizable DLT.
     assert(isDenseDLT(lvlType) || isCompressedDLT(lvlType) ||
-           isSingletonDLT(lvlType));
+           isCompressedWithHiDLT(lvlType) || isSingletonDLT(lvlType));

     // This is a slice-driven loop on sparse level.
     if (!dependentLvlMap[t][l].empty() && !isDenseDLT(lvlType)) {
@@ -901,7 +902,8 @@ Operation *LoopEmitter::enterCoIterationOverTensorsAtLvls(
     // TODO: support coiteration with slice driven tensors.
     const auto lvlTp = lvlTypes[tid][lvl];
     assert(dependentLvlMap[tid][lvl].empty() && "TODO: not yet implemented");
-    if (isCompressedDLT(lvlTp) || isSingletonDLT(lvlTp)) {
+    if (isCompressedDLT(lvlTp) || isSingletonDLT(lvlTp) ||
+        isCompressedWithHiDLT(lvlTp)) {
       const auto reassoc = getCollapseReassociation(tid, lvl);
       for (unsigned i = 0, e = reassoc.size() - 1; i < e; i++) {
         if (!isUniqueDLT(lvlTypes[tid][reassoc[i]])) {
@@ -941,7 +943,8 @@ Operation *LoopEmitter::enterCoIterationOverTensorsAtLvls(
   for (auto [t, lvl] : llvm::zip(tids, lvls)) {
     const TensorId tid = t; // Why `t` can not be captured by lambda?
     const auto lvlTp = lvlTypes[tid][lvl];
-    if (isCompressedDLT(lvlTp) || isSingletonDLT(lvlTp)) {
+    if (isCompressedDLT(lvlTp) || isSingletonDLT(lvlTp) ||
+        isCompressedWithHiDLT(lvlTp)) {
       const auto reassoc = getCollapseReassociation(tid, lvl);
       assert(reassoc.size() == 1 || isUniqueCOOType(tensors[tid].getType()));
       for (unsigned i = 0, e = reassoc.size() - 1; i < e; i++) {
@@ -974,7 +977,8 @@ Operation *LoopEmitter::enterCoIterationOverTensorsAtLvls(
   for (auto [tid, lvl] : llvm::zip(tids, lvls)) {
     // Prepares for next level.
     const auto lvlTp = lvlTypes[tid][lvl];
-    if (isCompressedDLT(lvlTp) || isSingletonDLT(lvlTp)) {
+    if (isCompressedDLT(lvlTp) || isSingletonDLT(lvlTp) ||
+        isCompressedWithHiDLT(lvlTp)) {
       coords[tid][lvl] = genSparseCrd(builder, loc, tid, lvl);
       if (isSparseSlices[tid]) {
         auto [trans, pred] =
@@ -1023,7 +1027,8 @@ Operation *LoopEmitter::enterCoIterationOverTensorsAtLvls(
   if (!needsUniv) {
     for (auto [tid, lvl] : llvm::zip(tids, lvls)) {
       const auto lvlTp = lvlTypes[tid][lvl];
-      if (isCompressedDLT(lvlTp) || isSingletonDLT(lvlTp)) {
+      if (isCompressedDLT(lvlTp) || isSingletonDLT(lvlTp) ||
+          isCompressedWithHiDLT(lvlTp)) {
        const auto crd = coords[tid][lvl];
        if (min) {
          Value cmp = CMPI(ult, coords[tid][lvl], min);
@@ -1117,12 +1122,14 @@ void LoopEmitter::prepareLoopOverTensorAtLvl(OpBuilder &builder, Location loc,
     // Either the first level, or the previous level has been set.
     /// FIXME: See the [CLARIFY_POSITS_LVL] note in the header.
     assert(srcLvl == 0 || posits[tid][srcLvl - 1]);
-    if (!isCompressedDLT(lvlTp) && !isSingletonDLT(lvlTp))
+    if (isDenseDLT(lvlTp))
       continue;
-    if (isCompressedDLT(lvlTp)) {
+    if (isCompressedDLT(lvlTp) || isCompressedWithHiDLT(lvlTp)) {
       const Value mem = positionsBuffers[tid][srcLvl];

-      const Value pLo = srcLvl == 0 ? c0 : posits[tid][srcLvl - 1];
+      Value pLo = srcLvl == 0 ? c0 : posits[tid][srcLvl - 1];
+      if (isCompressedWithHiDLT(lvlTp))
+        pLo = builder.create<arith::MulIOp>(loc, pLo, C_IDX(2));
       posits[tid][srcLvl] = genIndexLoad(builder, loc, mem, pLo);

       const Value pHi = ADDI(pLo, c1);
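The hunk above is where the paired layout shows up in generated code: for a compressed-hi level the parent position is doubled before the position loads, so the loop bounds become pos[2*p] and pos[2*p+1] rather than pos[p] and pos[p+1]. A scalar C++ analogue of the loop nest this produces for a dense-then-compressed-hi tensor (all names illustrative; a sketch rather than the emitter's actual output):

#include <cstdint>
#include <functional>
#include <vector>

// Scalar analogue of the loop nest the emitter builds when level 0 is dense
// and level 1 is compressed-hi. The 2*i / 2*i+1 indexing mirrors the
// pLo / pHi computation in the hunk above.
void forEachInCompressedHiLevel(
    uint64_t dimSize0, const std::vector<uint64_t> &pos1,
    const std::vector<uint64_t> &crd1,
    const std::function<void(uint64_t, uint64_t, uint64_t)> &visit) {
  for (uint64_t i = 0; i < dimSize0; ++i) {  // dense outer loop
    const uint64_t lo = pos1[2 * i];         // posits: load at pLo = 2 * i
    const uint64_t hi = pos1[2 * i + 1];     // highs:  load at pHi = pLo + 1
    for (uint64_t p = lo; p < hi; ++p)       // sparse inner loop
      visit(i, crd1[p], p);                  // (outer crd, inner crd, position)
  }
}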
@@ -1321,7 +1328,8 @@ void LoopEmitter::exitWhileLoop(OpBuilder &builder, Location loc,
   Value one = C_IDX(1);
   for (auto [tid, dstLvl] : llvm::zip(loopInfo.tids, loopInfo.lvls)) {
     const auto lvlTp = lvlTypes[tid][dstLvl];
-    if (isCompressedDLT(lvlTp) || isSingletonDLT(lvlTp)) {
+    if (isCompressedDLT(lvlTp) || isSingletonDLT(lvlTp) ||
+        isCompressedWithHiDLT(lvlTp)) {
       const auto reassoc = getCollapseReassociation(tid, dstLvl);
       assert(reassoc.size() == 1 || isUniqueCOOType(tensors[tid].getType()));
       for (unsigned i = 0, e = reassoc.size() - 1; i < e; i++) {
@@ -532,6 +532,8 @@ static void genEndInsert(OpBuilder &builder, Location loc,
   const Level lvlRank = stt.getLvlRank();
   for (Level l = 0; l < lvlRank; l++) {
     const auto dlt = stt.getLvlType(l);
+    if (isCompressedWithHiDLT(dlt))
+      llvm_unreachable("TODO: Not yet implemented");
     if (isCompressedDLT(dlt)) {
       // Compressed dimensions need a position cleanup for all entries
       // that were not visited during the insertion pass.
@@ -145,7 +145,7 @@ void sparse_tensor::foreachFieldInSparseTensor(
   // As a result, the compound type can be constructed directly in the given
   // order.
     const auto dlt = lvlTypes[l];
-    if (isCompressedDLT(dlt)) {
+    if (isCompressedDLT(dlt) || isCompressedWithHiDLT(dlt)) {
       RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::PosMemRef, l, dlt);
       RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::CrdMemRef, l, dlt);
     } else if (isSingletonDLT(dlt)) {
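With this change, a compressed-hi level contributes the same two storage fields as a compressed level: a positions memref and a coordinates memref. A small self-contained sketch of that enumeration using plain strings instead of DimLevelType (illustrative only; the trailing storage-specifier field is omitted):

#include <cstddef>
#include <string>
#include <vector>

// Sketch of the field enumeration above. Compressed and compressed-hi levels
// contribute a positions and a coordinates buffer, singleton levels only a
// coordinates buffer, dense levels nothing; a single values buffer follows.
std::vector<std::string> storageFields(const std::vector<std::string> &lvlTypes) {
  std::vector<std::string> fields;
  for (std::size_t l = 0; l < lvlTypes.size(); ++l) {
    const std::string &dlt = lvlTypes[l];
    if (dlt == "compressed" || dlt == "compressed-hi") {
      fields.push_back("positions[" + std::to_string(l) + "]");
      fields.push_back("coordinates[" + std::to_string(l) + "]");
    } else if (dlt == "singleton") {
      fields.push_back("coordinates[" + std::to_string(l) + "]");
    } // dense: no per-level buffer
  }
  fields.push_back("values");
  return fields;
}

// For the #BCOO encoding used in the test below, {"dense", "compressed-hi",
// "singleton"}, this yields: positions[1], coordinates[1], coordinates[2], values.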
@@ -794,7 +794,8 @@ static bool computeIterationGraph(CodegenEnv &env, SortMask mask,
     const TensorId tid = env.makeTensorId(t.getOperandNumber());
     for (LoopId i = 0; i < numLoops; i++) {
       const auto dltI = env.dlt(tid, i);
-      if (isCompressedDLT(dltI) || isSingletonDLT(dltI)) {
+      if (isCompressedDLT(dltI) || isCompressedWithHiDLT(dltI) ||
+          isSingletonDLT(dltI)) {
         for (LoopId j = 0; j < numLoops; j++)
           if (isUndefDLT(env.dlt(tid, j))) {
             adjM[i][j] = true;
@@ -1410,7 +1411,7 @@ static scf::IfOp genIf(CodegenEnv &env, OpBuilder &builder, LoopId ldx,
           DimLevelType dlt, bool /*unused*/) {
         assert(ldx == env.merger().loop(b));
         Value clause;
-        if (isCompressedDLT(dlt) || isSingletonDLT(dlt)) {
+        if (isCompressedDLT(dlt) || isSingletonDLT(dlt) || isCompressedWithHiDLT(dlt)) {
          assert(lvl.has_value());
          const Value crd = env.emitter().getCoords()[tid][*lvl];
          const Value lvar = env.getLoopVar(ldx);
@@ -418,7 +418,7 @@ BitVector Merger::simplifyCond(LatSetId s0, LatPointId p0) {
     // Slice on dense level has `locate` property as well, and can be optimized.
     if (simple[b] && !isSparseLvlWithNonTrivialIdxExp(b)) {
       const auto dlt = getDimLevelType(b);
-      if (!isCompressedDLT(dlt) && !isSingletonDLT(dlt)) {
+      if (!isCompressedDLT(dlt) && !isSingletonDLT(dlt) && !isCompressedWithHiDLT(dlt)) {
        if (reset)
          simple.reset(b);
        reset = true;
@@ -585,7 +585,7 @@ bool Merger::isSingleCondition(TensorId t, ExprId e) const {
 bool Merger::hasAnySparse(const BitVector &bits) const {
   for (TensorLoopId b : bits.set_bits()) {
     const auto dlt = getDimLevelType(b);
-    if (isCompressedDLT(dlt) || isSingletonDLT(dlt))
+    if (isCompressedDLT(dlt) || isSingletonDLT(dlt) || isCompressedWithHiDLT(dlt))
       return true;
   }
   return hasSparseIdxReduction(bits);
@@ -138,4 +138,34 @@ func.func @foreach_print_slice(%A: tensor<4x4xf64, #CSR_SLICE>) {
     "test.use" (%v) : (f64) -> ()
   }
   return
 }
+
+#BCOO = #sparse_tensor.encoding<{
+  dimLevelType = [ "dense", "compressed-hi-nu", "singleton" ],
+}>
+
+// CHECK-LABEL: func.func @foreach_bcoo(
+// CHECK-SAME:    %[[VAL_0:.*]]: tensor<4x4x4xf64, #sparse_tensor.encoding<{{.*}}>>) {
+// CHECK-DAG:     %[[VAL_1:.*]] = arith.constant 4 : index
+// CHECK-DAG:     %[[VAL_2:.*]] = arith.constant 0 : index
+// CHECK-DAG:     %[[VAL_3:.*]] = arith.constant 1 : index
+// CHECK-DAG:     %[[VAL_4:.*]] = arith.constant 2 : index
+// CHECK-DAG:     %[[VAL_5:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<4x4x4xf64, #sparse_tensor.encoding<{{.*}}>> to memref<?xindex>
+// CHECK-DAG:     %[[VAL_6:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<4x4x4xf64, #sparse_tensor.encoding<{{.*}}>> to memref<?xf64>
+// CHECK:         scf.for %[[VAL_7:.*]] = %[[VAL_2]] to %[[VAL_1]] step %[[VAL_3]] {
+// CHECK:           %[[VAL_8:.*]] = arith.muli %[[VAL_7]], %[[VAL_4]] : index
+// CHECK:           %[[VAL_9:.*]] = arith.addi %[[VAL_8]], %[[VAL_3]] : index
+// CHECK:           %[[VAL_10:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_9]]] : memref<?xindex>
+// CHECK:           scf.for %[[VAL_11:.*]] = %[[VAL_2]] to %[[VAL_10]] step %[[VAL_3]] {
+// CHECK:             %[[VAL_12:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_11]]] : memref<?xf64>
+// CHECK:             "test.use"(%[[VAL_12]]) : (f64) -> ()
+// CHECK:           }
+// CHECK:         }
+// CHECK:         return
+func.func @foreach_bcoo(%A: tensor<4x4x4xf64, #BCOO>) {
+  sparse_tensor.foreach in %A : tensor<4x4x4xf64, #BCOO> do {
+  ^bb0(%1: index, %2: index, %3: index, %v: f64) :
+    "test.use" (%v) : (f64) -> ()
+  }
+  return
+}