diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h index 41a14575ed10..a00c9c31256c 100644 --- a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h @@ -283,7 +283,13 @@ public: } bool operator!=(const LevelType lhs) const { return !(*this == lhs); } - LevelType stripProperties() const { return LevelType(lvlBits & ~0xffff); } + LevelType stripStorageIrrelevantProperties() const { + // Properties other than `SoA` do not change the storage scheme of the + // sparse tensor. + constexpr uint64_t mask = + 0xffff & ~static_cast<uint64_t>(LevelPropNonDefault::SoA); + return LevelType(lvlBits & ~mask); + } /// Get N of NOutOfM level type. constexpr uint64_t getN() const { diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h index 24a5640d820e..1a090ddb782f 100644 --- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h @@ -24,6 +24,7 @@ struct COOSegment { std::pair<Level, Level> lvlRange; // [low, high) bool isSoA; + bool isAoS() const { return !isSoA; } bool isSegmentStart(Level l) const { return l == lvlRange.first; } bool inSegment(Level l) const { return l >= lvlRange.first && l < lvlRange.second; @@ -337,7 +338,9 @@ public: /// Returns the starting level of this sparse tensor type for a /// trailing COO region that spans **at least** two levels. If /// no such COO region is found, then returns the level-rank. - Level getCOOStart() const; + /// + /// DEPRECATED: use getCOOSegments instead. + Level getAoSCOOStart() const; /// Returns [un]ordered COO type for this sparse tensor type. 
RankedTensorType getCOOType(bool ordered) const; diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp index 53e78d2c28b1..af7b85d45877 100644 --- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp @@ -182,7 +182,7 @@ StorageLayout::getFieldIndexAndStride(SparseTensorFieldKind kind, unsigned stride = 1; if (kind == SparseTensorFieldKind::CrdMemRef) { assert(lvl.has_value()); - const Level cooStart = SparseTensorType(enc).getCOOStart(); + const Level cooStart = SparseTensorType(enc).getAoSCOOStart(); const Level lvlRank = enc.getLvlRank(); if (lvl.value() >= cooStart && lvl.value() < lvlRank) { lvl = cooStart; @@ -811,10 +811,10 @@ bool mlir::sparse_tensor::SparseTensorType::isCOOType(Level startLvl, return !isUnique || isUniqueLvl(lvlRank - 1); } -Level mlir::sparse_tensor::SparseTensorType::getCOOStart() const { +Level mlir::sparse_tensor::SparseTensorType::getAoSCOOStart() const { SmallVector<COOSegment> coo = getCOOSegments(); - if (!coo.empty()) { - assert(coo.size() == 1); + assert(coo.size() == 1 || coo.empty()); + if (!coo.empty() && coo.front().isAoS()) { return coo.front().lvlRange.first; } return lvlRank; @@ -1051,7 +1051,7 @@ static SparseTensorEncodingAttr getNormalizedEncodingForSpecifier(SparseTensorEncodingAttr enc) { SmallVector<LevelType> lts; for (auto lt : enc.getLvlTypes()) - lts.push_back(lt.stripProperties()); + lts.push_back(lt.stripStorageIrrelevantProperties()); return SparseTensorEncodingAttr::get( enc.getContext(), lts, @@ -1137,7 +1137,7 @@ static LogicalResult verifyPackUnPack(Operation *op, bool requiresStaticShape, return op->emitError("the sparse-tensor must have an encoding attribute"); // Verifies the trailing COO. - Level cooStartLvl = stt.getCOOStart(); + Level cooStartLvl = stt.getAoSCOOStart(); if (cooStartLvl < stt.getLvlRank()) { // We only supports trailing COO for now, must be the last input. 
auto cooTp = llvm::cast<RankedTensorType>(lvlTps.back()); @@ -1452,7 +1452,7 @@ LogicalResult ToCoordinatesOp::verify() { LogicalResult ToCoordinatesBufferOp::verify() { auto stt = getSparseTensorType(getTensor()); - if (stt.getCOOStart() >= stt.getLvlRank()) + if (stt.getAoSCOOStart() >= stt.getLvlRank()) return emitError("expected sparse tensor with a COO region"); return success(); } diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp index d4459c6ea1e5..0ccb11f3a6b8 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp @@ -194,7 +194,7 @@ static void createAllocFields(OpBuilder &builder, Location loc, valHeuristic = builder.create<arith::MulIOp>(loc, valHeuristic, lvlSizesValues[lvl]); } else if (sizeHint) { - if (stt.getCOOStart() == 0) { + if (stt.getAoSCOOStart() == 0) { posHeuristic = constantIndex(builder, loc, 2); crdHeuristic = builder.create<arith::MulIOp>( loc, constantIndex(builder, loc, lvlRank), sizeHint); // AOS @@ -1316,7 +1316,7 @@ struct SparseAssembleOpConverter : public OpConversionPattern<AssembleOp> { Value posBack = c0; // index to the last value in the position array Value memSize = c1; // memory size for current array - Level trailCOOStart = stt.getCOOStart(); + Level trailCOOStart = stt.getAoSCOOStart(); Level trailCOORank = stt.getLvlRank() - trailCOOStart; // Sets up SparseTensorSpecifier. for (Level lvl = 0, lvlRank = stt.getLvlRank(); lvl < lvlRank; lvl++) { @@ -1453,7 +1453,7 @@ struct SparseNewConverter : public OpConversionPattern<NewOp> { const auto dstTp = getSparseTensorType(op.getResult()); // Creating COO with NewOp is handled by direct IR codegen. All other cases // are handled by rewriting. 
- if (!dstTp.hasEncoding() || dstTp.getCOOStart() != 0) + if (!dstTp.hasEncoding() || dstTp.getAoSCOOStart() != 0) return failure(); // Implement as follows: diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp index 7326a6a38112..2ccb2361b5ef 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp @@ -1180,7 +1180,7 @@ struct NewRewriter : public OpRewritePattern<NewOp> { PatternRewriter &rewriter) const override { Location loc = op.getLoc(); auto stt = getSparseTensorType(op.getResult()); - if (!stt.hasEncoding() || stt.getCOOStart() == 0) + if (!stt.hasEncoding() || stt.getAoSCOOStart() == 0) return failure(); // Implement the NewOp as follows: diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp index 75a438914918..b888dfadb9c7 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp @@ -568,7 +568,7 @@ Value sparse_tensor::genToCoordinates(OpBuilder &builder, Location loc, const auto srcTp = getSparseTensorType(tensor); const Type crdTp = srcTp.getCrdType(); const Type memTp = - get1DMemRefType(crdTp, /*withLayout=*/lvl >= srcTp.getCOOStart()); + get1DMemRefType(crdTp, /*withLayout=*/lvl >= srcTp.getAoSCOOStart()); return builder.create<ToCoordinatesOp>(loc, memTp, tensor, builder.getIndexAttr(lvl)); } diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp index 3ab4157475cd..6ac26ad550f9 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp @@ -103,7 +103,7 @@ void 
SparseTensorSpecifier::setSpecifierField(OpBuilder &builder, Location loc, Value sparse_tensor::SparseTensorDescriptor::getCrdMemRefOrView( OpBuilder &builder, Location loc, Level lvl) const { - const Level cooStart = rType.getCOOStart(); + const Level cooStart = rType.getAoSCOOStart(); if (lvl < cooStart) return getMemRefField(SparseTensorFieldKind::CrdMemRef, lvl); diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.h index 3a61ec7a2236..c2f631605bf4 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.h +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.h @@ -137,7 +137,7 @@ public: } Value getAOSMemRef() const { - const Level cooStart = rType.getCOOStart(); + const Level cooStart = rType.getAoSCOOStart(); assert(cooStart < rType.getLvlRank()); return getMemRefField(SparseTensorFieldKind::CrdMemRef, cooStart); } diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_coo_test.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_coo_test.mlir index aaf15ecc681f..16252c1005eb 100644 --- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_coo_test.mlir +++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_coo_test.mlir @@ -34,6 +34,10 @@ map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton) }> +#SortedCOOSoA = #sparse_tensor.encoding<{ + map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton(soa)) +}> + #CSR = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }> @@ -50,7 +54,7 @@ module { func.func @add_coo_csr(%arga: tensor<8x8xf32, #CSR>, - %argb: tensor<8x8xf32, #SortedCOO>) + %argb: tensor<8x8xf32, #SortedCOOSoA>) -> tensor<8x8xf32> { %empty = tensor.empty() : tensor<8x8xf32> %zero = arith.constant 0.000000e+00 : f32 @@ -59,7 +63,7 @@ module { outs(%empty : tensor<8x8xf32>) -> tensor<8x8xf32> %0 = linalg.generic #trait ins(%arga, %argb: 
tensor<8x8xf32, #CSR>, - tensor<8x8xf32, #SortedCOO>) + tensor<8x8xf32, #SortedCOOSoA>) outs(%init: tensor<8x8xf32>) { ^bb(%a: f32, %b: f32, %x: f32): %0 = arith.addf %a, %b : f32 @@ -69,7 +73,7 @@ module { } func.func @add_coo_coo(%arga: tensor<8x8xf32, #SortedCOO>, - %argb: tensor<8x8xf32, #SortedCOO>) + %argb: tensor<8x8xf32, #SortedCOOSoA>) -> tensor<8x8xf32> { %empty = tensor.empty() : tensor<8x8xf32> %zero = arith.constant 0.000000e+00 : f32 @@ -78,7 +82,7 @@ module { outs(%empty : tensor<8x8xf32>) -> tensor<8x8xf32> %0 = linalg.generic #trait ins(%arga, %argb: tensor<8x8xf32, #SortedCOO>, - tensor<8x8xf32, #SortedCOO>) + tensor<8x8xf32, #SortedCOOSoA>) outs(%init: tensor<8x8xf32>) { ^bb(%a: f32, %b: f32, %x: f32): %0 = arith.addf %a, %b : f32 @@ -88,12 +92,12 @@ module { } func.func @add_coo_coo_out_coo(%arga: tensor<8x8xf32, #SortedCOO>, - %argb: tensor<8x8xf32, #SortedCOO>) + %argb: tensor<8x8xf32, #SortedCOOSoA>) -> tensor<8x8xf32, #SortedCOO> { %init = tensor.empty() : tensor<8x8xf32, #SortedCOO> %0 = linalg.generic #trait ins(%arga, %argb: tensor<8x8xf32, #SortedCOO>, - tensor<8x8xf32, #SortedCOO>) + tensor<8x8xf32, #SortedCOOSoA>) outs(%init: tensor<8x8xf32, #SortedCOO>) { ^bb(%a: f32, %b: f32, %x: f32): %0 = arith.addf %a, %b : f32 @@ -104,7 +108,7 @@ module { func.func @add_coo_dense(%arga: tensor<8x8xf32>, - %argb: tensor<8x8xf32, #SortedCOO>) + %argb: tensor<8x8xf32, #SortedCOOSoA>) -> tensor<8x8xf32> { %empty = tensor.empty() : tensor<8x8xf32> %zero = arith.constant 0.000000e+00 : f32 @@ -113,7 +117,7 @@ module { outs(%empty : tensor<8x8xf32>) -> tensor<8x8xf32> %0 = linalg.generic #trait ins(%arga, %argb: tensor<8x8xf32>, - tensor<8x8xf32, #SortedCOO>) + tensor<8x8xf32, #SortedCOOSoA>) outs(%init: tensor<8x8xf32>) { ^bb(%a: f32, %b: f32, %x: f32): %0 = arith.addf %a, %b : f32 @@ -154,19 +158,19 @@ module { %COO_A = sparse_tensor.convert %A : tensor<8x8xf32> to tensor<8x8xf32, #SortedCOO> %COO_B = sparse_tensor.convert %B - : tensor<8x8xf32> to 
tensor<8x8xf32, #SortedCOO> + : tensor<8x8xf32> to tensor<8x8xf32, #SortedCOOSoA> %C1 = call @add_coo_dense(%A, %COO_B) : (tensor<8x8xf32>, - tensor<8x8xf32, #SortedCOO>) + tensor<8x8xf32, #SortedCOOSoA>) -> tensor<8x8xf32> %C2 = call @add_coo_csr(%CSR_A, %COO_B) : (tensor<8x8xf32, #CSR>, - tensor<8x8xf32, #SortedCOO>) + tensor<8x8xf32, #SortedCOOSoA>) -> tensor<8x8xf32> %C3 = call @add_coo_coo(%COO_A, %COO_B) : (tensor<8x8xf32, #SortedCOO>, - tensor<8x8xf32, #SortedCOO>) + tensor<8x8xf32, #SortedCOOSoA>) -> tensor<8x8xf32> %COO_RET = call @add_coo_coo_out_coo(%COO_A, %COO_B) : (tensor<8x8xf32, #SortedCOO>, - tensor<8x8xf32, #SortedCOO>) + tensor<8x8xf32, #SortedCOOSoA>) -> tensor<8x8xf32, #SortedCOO> %C4 = sparse_tensor.convert %COO_RET : tensor<8x8xf32, #SortedCOO> to tensor<8x8xf32> // @@ -204,7 +208,7 @@ module { bufferization.dealloc_tensor %C4 : tensor<8x8xf32> bufferization.dealloc_tensor %CSR_A : tensor<8x8xf32, #CSR> bufferization.dealloc_tensor %COO_A : tensor<8x8xf32, #SortedCOO> - bufferization.dealloc_tensor %COO_B : tensor<8x8xf32, #SortedCOO> + bufferization.dealloc_tensor %COO_B : tensor<8x8xf32, #SortedCOOSoA> bufferization.dealloc_tensor %COO_RET : tensor<8x8xf32, #SortedCOO>