mirror of
https://github.com/intel/llvm.git
synced 2026-02-01 17:07:36 +08:00
[mlir][sparse] do not ignore ordering for "dense" tensor linked with sparse type
Reviewed By: bixia Differential Revision: https://reviews.llvm.org/D97795
This commit is contained in:
@@ -357,6 +357,14 @@ static void findSparseAnnotations(Merger &merger, linalg::GenericOp op) {
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns true if tensor was set up with sparse storage scheme.
|
||||
static bool linkedSparse(linalg::GenericOp op, unsigned tensor) {
|
||||
if (tensor < op.getNumInputs())
|
||||
return isa_and_nonnull<linalg::SparseTensorFromPointerOp>(
|
||||
op.getInput(tensor).getDefiningOp());
|
||||
return false;
|
||||
}
|
||||
|
||||
/// A DFS helper to compute a topological sort. Note that recursion is
|
||||
/// bounded by the number of implicit loops, which is always small.
|
||||
/// Returns false when a cycle is detected.
|
||||
@@ -394,7 +402,7 @@ static bool computeIterationGraph(Merger &merger, linalg::GenericOp op,
|
||||
auto map = op.getIndexingMap(t);
|
||||
assert(map.getNumDims() == n);
|
||||
// Skip dense tensor constraints when sparse only is requested.
|
||||
if (sparseOnly && !merger.isSparseTensor(t))
|
||||
if (sparseOnly && !merger.isSparseTensor(t) && !linkedSparse(op, t))
|
||||
continue;
|
||||
// At the moment, we take the index variables in the tensor access
|
||||
// expression in the order in which they appear (conceptually a
|
||||
@@ -513,14 +521,6 @@ static Type genIntType(PatternRewriter &rewriter, linalg::SparseIntType tp) {
|
||||
llvm_unreachable("unexpected SparseIntType");
|
||||
}
|
||||
|
||||
/// Returns true if tensor was set up with sparse storage scheme.
|
||||
static bool linkedSparse(linalg::GenericOp op, unsigned tensor) {
|
||||
if (tensor < op.getNumInputs())
|
||||
return isa_and_nonnull<linalg::SparseTensorFromPointerOp>(
|
||||
op.getInput(tensor).getDefiningOp());
|
||||
return false;
|
||||
}
|
||||
|
||||
/// Generates buffer for the output tensor.
|
||||
static Value genOutputBuffer(CodeGen &codegen, PatternRewriter &rewriter,
|
||||
linalg::GenericOp op, MemRefType denseTp,
|
||||
@@ -1004,7 +1004,7 @@ static Operation *genWhile(Merger &merger, CodeGen &codegen,
|
||||
if (needsUniv) {
|
||||
types.push_back(indexType);
|
||||
assert(codegen.loops[idx].getType().isa<IndexType>() &&
|
||||
"type_mismatch for universal index");
|
||||
"type mismatch for universal index");
|
||||
operands.push_back(codegen.loops[idx]);
|
||||
}
|
||||
Location loc = op.getLoc();
|
||||
|
||||
Reference in New Issue
Block a user