[mlir][sparse] do not ignore ordering for "dense" tensor linked with sparse type

Reviewed By: bixia

Differential Revision: https://reviews.llvm.org/D97795
This commit is contained in:
Aart Bik
2021-03-02 12:17:13 -08:00
parent 16005fd979
commit 5b333d3449

View File

@@ -357,6 +357,14 @@ static void findSparseAnnotations(Merger &merger, linalg::GenericOp op) {
}
}
/// Returns true if tensor was set up with sparse storage scheme.
static bool linkedSparse(linalg::GenericOp op, unsigned tensor) {
if (tensor < op.getNumInputs())
return isa_and_nonnull<linalg::SparseTensorFromPointerOp>(
op.getInput(tensor).getDefiningOp());
return false;
}
/// A DFS helper to compute a topological sort. Note that recursion is
/// bounded by the number of implicit loops, which is always small.
/// Returns false when a cycle is detected.
@@ -394,7 +402,7 @@ static bool computeIterationGraph(Merger &merger, linalg::GenericOp op,
auto map = op.getIndexingMap(t);
assert(map.getNumDims() == n);
// Skip dense tensor constraints when sparse only is requested.
if (sparseOnly && !merger.isSparseTensor(t))
if (sparseOnly && !merger.isSparseTensor(t) && !linkedSparse(op, t))
continue;
// At the moment, we take the index variables in the tensor access
// expression in the order in which they appear (conceptually a
@@ -513,14 +521,6 @@ static Type genIntType(PatternRewriter &rewriter, linalg::SparseIntType tp) {
llvm_unreachable("unexpected SparseIntType");
}
/// Returns true if tensor was set up with sparse storage scheme.
static bool linkedSparse(linalg::GenericOp op, unsigned tensor) {
if (tensor < op.getNumInputs())
return isa_and_nonnull<linalg::SparseTensorFromPointerOp>(
op.getInput(tensor).getDefiningOp());
return false;
}
/// Generates buffer for the output tensor.
static Value genOutputBuffer(CodeGen &codegen, PatternRewriter &rewriter,
linalg::GenericOp op, MemRefType denseTp,
@@ -1004,7 +1004,7 @@ static Operation *genWhile(Merger &merger, CodeGen &codegen,
if (needsUniv) {
types.push_back(indexType);
assert(codegen.loops[idx].getType().isa<IndexType>() &&
"type_mismatch for universal index");
"type mismatch for universal index");
operands.push_back(codegen.loops[idx]);
}
Location loc = op.getLoc();