[mlir][sparse] avoid unnecessary tmp COO buffer and convert when lowering ConcatenateOp.

When concatenating along dimension 0, and all inputs/outputs are ordered with an identity
dimension ordering, the concatenated coordinates are yielded in lexicographic order, so there is no need to sort.
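
For illustration (a hypothetical example, not part of this patch), a dimension-0
concatenation of identity-ordered DCSR operands such as the following can now
insert directly into the destination tensor, since its coordinates are already
produced in lexicographic order:

  #DCSR = #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>

  // Hypothetical example: rows of %a come first, then rows of %b offset by 2,
  // so the (i, j) coordinates are generated in lexicographic order and the
  // lowering can insert directly into the result, without a tmp COO or a sort.
  func.func @concat_dim0(%a: tensor<2x4xf64, #DCSR>,
                         %b: tensor<3x4xf64, #DCSR>) -> tensor<5x4xf64, #DCSR> {
    %c = sparse_tensor.concatenate %a, %b {dimension = 0 : index}
         : tensor<2x4xf64, #DCSR>, tensor<3x4xf64, #DCSR> to tensor<5x4xf64, #DCSR>
    return %c : tensor<5x4xf64, #DCSR>
  }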

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D140228
Peiming Liu
2022-12-16 18:12:05 +00:00
parent 470bc76b13
commit a3672add76
5 changed files with 44 additions and 12 deletions


@@ -173,6 +173,12 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
/// Constructs a new encoding with the dimOrdering and higherOrdering
/// reset to the default/identity.
SparseTensorEncodingAttr withoutOrdering() const;
/// Return true if every level is dense in the encoding.
bool isAllDense() const;
/// Return true if the encoding has an identity dimension ordering.
bool hasIdDimOrdering() const;
}];
let genVerifyDecl = 1;


@@ -63,6 +63,14 @@ SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutOrdering() const {
getPointerBitWidth(), getIndexBitWidth());
}
bool SparseTensorEncodingAttr::isAllDense() const {
return llvm::all_of(getDimLevelType(), isDenseDLT);
}
bool SparseTensorEncodingAttr::hasIdDimOrdering() const {
return !getDimOrdering() || getDimOrdering().isIdentity();
}
Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
if (failed(parser.parseLess()))
return {};
@@ -172,7 +180,7 @@ void SparseTensorEncodingAttr::print(AsmPrinter &printer) const {
}
printer << " ]";
// Print remaining members only for non-default values.
if (getDimOrdering() && !getDimOrdering().isIdentity())
if (!hasIdDimOrdering())
printer << ", dimOrdering = affine_map<" << getDimOrdering() << ">";
if (getHigherOrdering())
printer << ", higherOrdering = affine_map<" << getHigherOrdering() << ">";


@@ -1349,8 +1349,7 @@ public:
bool allDense = false;
Value dstTensor;
if (encDst) {
allDense = llvm::all_of(encDst.getDimLevelType(),
[](DimLevelType dlt) { return isDenseDLT(dlt); });
allDense = encDst.isAllDense();
// Start a new COO or an initialized annotated all dense sparse tensor.
dst = params.genBuffers(encDst, sizes, dstTp)
.genNewCall(allDense ? Action::kEmpty : Action::kEmptyCOO);


@@ -525,14 +525,35 @@ struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
// %t = convert_to_dest_tensor(%tmp)
SparseTensorEncodingAttr encDst = getSparseTensorEncoding(dstTp);
Value dst; // Destination tensor for inserting source tensor values.
bool allDense = false;
bool needTmpCOO = true;
if (encDst) {
allDense = llvm::all_of(encDst.getDimLevelType(),
[](DimLevelType dlt) { return isDenseDLT(dlt); });
bool allDense = encDst.isAllDense();
bool allOrdered = false;
// When concatenating on dimension 0, and all inputs are sorted and have
// an identity dimOrdering, the concatenate will generate coords in
// lexOrder thus no need for the tmp COO buffer.
// TODO: When conDim != 0, as long as conDim is the first dimension
// in all input/output buffers, and all input/output buffers have the same
// dimOrdering, the tmp COO buffer is still unnecessary (e.g, concatenate
// CSC matrices along column).
if (!allDense && conDim == 0 && encDst.hasIdDimOrdering()) {
for (auto i : op.getInputs()) {
auto rtp = i.getType().cast<RankedTensorType>();
auto srcEnc = getSparseTensorEncoding(rtp);
if (isAllDimOrdered(rtp) && (!srcEnc || srcEnc.hasIdDimOrdering())) {
allOrdered = true;
continue;
}
allOrdered = false;
break;
}
}
needTmpCOO = !allDense && !allOrdered;
SmallVector<Value> dynSizes;
getDynamicSizes(dstTp, sizes, dynSizes);
RankedTensorType tp = dstTp;
if (!allDense) {
if (needTmpCOO) {
tp = getUnorderedCOOFromType(dstTp);
encDst = getSparseTensorEncoding(tp);
}
@@ -596,7 +617,7 @@ struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
if (encDst) {
dst = rewriter.create<LoadOp>(loc, dst, true);
if (!allDense) {
if (needTmpCOO) {
Value tmpCoo = dst;
dst = rewriter.create<ConvertOp>(loc, dstTp, tmpCoo).getResult();
rewriter.create<DeallocTensorOp>(loc, tmpCoo);
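
A hypothetical sketch of the case the TODO above refers to (not handled by this
patch): concatenating CSC matrices along dimension 1 appends whole columns, so
as long as all operands share the same column-major dimOrdering, the coordinates
would again be generated in order and the tmp COO buffer would be unnecessary:

  #CSC = #sparse_tensor.encoding<{
    dimLevelType = [ "dense", "compressed" ],
    dimOrdering = affine_map<(i, j) -> (j, i)>
  }>

  // Hypothetical example of the TODO case: column-wise concatenation of CSC
  // matrices; the concatenation dimension is the outermost stored dimension.
  func.func @concat_csc_dim1(%a: tensor<4x2xf64, #CSC>,
                             %b: tensor<4x3xf64, #CSC>) -> tensor<4x5xf64, #CSC> {
    %c = sparse_tensor.concatenate %a, %b {dimension = 1 : index}
         : tensor<4x2xf64, #CSC>, tensor<4x3xf64, #CSC> to tensor<4x5xf64, #CSC>
    return %c : tensor<4x5xf64, #CSC>
  }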


@@ -79,8 +79,7 @@
// CHECK: scf.yield %[[RET_6]]
// CHECK: }
// CHECK: %[[TMP_23:.*]] = sparse_tensor.load %[[RET_3]] hasInserts
// CHECK: %[[TMP_22:.*]] = sparse_tensor.convert %[[TMP_23]] : tensor<9x4xf64, #sparse_tensor
// CHECK: return %[[TMP_22]] : tensor<9x4xf64, #sparse_tensor
// CHECK: return %[[TMP_23]] : tensor<9x4xf64, #sparse_tensor
func.func @concat_sparse_sparse(%arg0: tensor<2x4xf64, #DCSR>,
%arg1: tensor<3x4xf64, #DCSR>,
%arg2: tensor<4x4xf64, #DCSR>)
@@ -166,8 +165,7 @@ func.func @concat_sparse_sparse(%arg0: tensor<2x4xf64, #DCSR>,
// CHECK: scf.yield %[[RET_6]]
// CHECK: }
// CHECK: %[[TMP_23:.*]] = sparse_tensor.load %[[RET_3]] hasInserts
// CHECK: %[[TMP_22:.*]] = sparse_tensor.convert %[[TMP_23]] : tensor<?x?xf64, #sparse_tensor
// CHECK: return %[[TMP_22]] : tensor<?x?xf64, #sparse_tensor
// CHECK: return %[[TMP_23]] : tensor<?x?xf64, #sparse_tensor
func.func @concat_sparse_sparse_dynamic(%arg0: tensor<2x4xf64, #DCSR>,
%arg1: tensor<3x4xf64, #DCSR>,
%arg2: tensor<4x4xf64, #DCSR>)