[mlir][sparse] relax constraints on tensor.cast with pre-rewriting

Reviewed By: wrengr

Differential Revision: https://reviews.llvm.org/D149489
This commit is contained in:
Aart Bik
2023-05-01 12:56:26 -07:00
parent dc049a4ea6
commit 9a018a7b48
6 changed files with 99 additions and 16 deletions

View File

@@ -142,6 +142,9 @@ FailureOr<Value> getOrCreateDestination(OpBuilder &b, Location loc,
LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op,
SmallVector<Value> &result);
/// Tests if types are the same when ignoring encoding on ranked tensors.
bool isSameTypeWithoutEncoding(Type tp1, Type tp2);
/// Function to control the folding of constant and extract slice.
using ControlConstantExtractSliceFusionFn = std::function<bool(ExtractSliceOp)>;

View File

@@ -449,8 +449,6 @@ mlir::sparse_tensor::getSparseTensorEncoding(Type type) {
return nullptr;
}
/// Returns true iff the given sparse tensor encoding attribute has a trailing
/// COO region starting at the given level.
bool mlir::sparse_tensor::isCOOType(SparseTensorEncodingAttr enc,
Level startLvl, bool isUnique) {
if (!enc ||

View File

@@ -346,6 +346,45 @@ private:
}
};
// Fuse a tensor cast into producing operation. Note that a tensor.cast
// should really not be used to convert between sparse encodings. Since
// the pattern currently appears as a result of some prior rewriting
// we make an attempt to repair very obvious cases.
// TODO: audit the pure tensor dialect rewriting rules
struct FuseTensorCast : public OpRewritePattern<tensor::CastOp> {
public:
using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
LogicalResult matchAndRewrite(tensor::CastOp op,
PatternRewriter &rewriter) const override {
Type srcType = op.getSource().getType();
Type dstType = op.getDest().getType();
// A nop cast simply folds away.
if (srcType == dstType) {
rewriter.replaceOp(op, op->getResults());
return success();
}
// See if a sparsity changing cast can be fused into producer.
if (tensor::isSameTypeWithoutEncoding(srcType, dstType)) {
if (Operation *def = op.getSource().getDefiningOp()) {
if (def->hasOneUse() && isa<tensor::ExtractSliceOp>(def)) {
def->getResult(0).setType(op->getResultTypes()[0]);
rewriter.replaceOp(op, def->getResult(0));
return success();
}
}
}
// Repair tensor casts with at least one sparse operand into the
// the properly supported sparse_tensor.convert.
if (getSparseTensorEncoding(srcType) || getSparseTensorEncoding(dstType)) {
rewriter.replaceOpWithNewOp<ConvertOp>(op, dstType, op.getSource());
return success();
}
// Fail otherwise.
return failure();
}
};
/// Sparse rewriting rule for sparse-to-sparse reshape operator.
template <typename ReshapeOp>
struct Sparse2SparseReshapeRewriter : public OpRewritePattern<ReshapeOp> {
@@ -1125,7 +1164,7 @@ struct OutRewriter : public OpRewritePattern<OutOp> {
//===---------------------------------------------------------------------===//
void mlir::populatePreSparsificationRewriting(RewritePatternSet &patterns) {
patterns.add<FoldInvariantYield, FuseSparseMultiplyOverAdd>(
patterns.add<FoldInvariantYield, FuseSparseMultiplyOverAdd, FuseTensorCast>(
patterns.getContext());
}

View File

@@ -110,6 +110,16 @@ LogicalResult tensor::getOrCreateDestinations(OpBuilder &b, Location loc,
return success();
}
bool tensor::isSameTypeWithoutEncoding(Type tp1, Type tp2) {
if (auto rtp1 = tp1.dyn_cast<RankedTensorType>()) {
if (auto rtp2 = tp2.dyn_cast<RankedTensorType>())
return rtp1.getShape() == rtp2.getShape() &&
rtp1.getElementType() == rtp2.getElementType();
return false;
}
return tp1 == tp2; // default implementation
}
/// Compute the dropped dimensions of a rank-reducing tensor.extract_slice op or
/// rank-extending tensor.insert_slice op.
static llvm::SmallBitVector getDroppedDims(ArrayRef<int64_t> reducedShape,
@@ -1343,18 +1353,6 @@ void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
getReassociationIndicesAttribute(b, reassociation));
}
// Checks if types are the same, but ignoring encoding on ranked tensors.
static bool isSameTypesWithoutEncoding(Type tp1, Type tp2) {
if (auto rtp1 = tp1.dyn_cast<RankedTensorType>()) {
if (auto rtp2 = tp2.dyn_cast<RankedTensorType>())
return rtp1.getShape() == rtp2.getShape() &&
rtp1.getElementType() == rtp2.getElementType();
return false;
}
// Default implementation.
return tp1 == tp2;
}
template <typename TensorReshapeOp, bool isExpansion = std::is_same<
TensorReshapeOp, ExpandShapeOp>::value>
static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op,
@@ -1367,7 +1365,7 @@ static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op,
auto maps = op.getReassociationMaps();
RankedTensorType expectedType =
CollapseShapeOp::inferCollapsedType(expandedType, maps);
if (!isSameTypesWithoutEncoding(collapsedType, expectedType))
if (!isSameTypeWithoutEncoding(collapsedType, expectedType))
return op.emitOpError("expected collapsed type to be ")
<< expectedType << ", but got " << collapsedType;
return success();

View File

@@ -0,0 +1,45 @@
// RUN: mlir-opt %s -pre-sparsification-rewrite | FileCheck %s
#SparseVector = #sparse_tensor.encoding<{
dimLevelType = ["compressed"]
}>
#SortedCOO = #sparse_tensor.encoding<{
dimLevelType = [ "compressed-nu", "singleton" ]
}>
#Slice = #sparse_tensor.encoding<{
dimLevelType = [ "compressed-nu", "singleton" ],
slice = [ (?, 1, 1), (?, 3, 1) ]
}>
// CHECK-LABEL: func @sparse_nop_cast(
// CHECK-SAME: %[[A:.*]]: tensor<?xf32, #sparse_tensor.encoding<{{{.*}}}>>)
// CHECK: return %[[A]] : tensor<?xf32, #sparse_tensor.encoding<{{{.*}}}>>
func.func @sparse_nop_cast(%a : tensor<?xf32, #SparseVector>) -> tensor<?xf32, #SparseVector> {
%0 = tensor.cast %a : tensor<?xf32, #SparseVector> to tensor<?xf32, #SparseVector>
%1 = tensor.cast %0 : tensor<?xf32, #SparseVector> to tensor<?xf32, #SparseVector>
%2 = tensor.cast %1 : tensor<?xf32, #SparseVector> to tensor<?xf32, #SparseVector>
return %2 : tensor<?xf32, #SparseVector>
}
// CHECK-LABEL: func @sparse_repair_cast(
// CHECK-SAME: %[[A:.*]]: tensor<?xf32>)
// CHECK: %[[C:.*]] = sparse_tensor.convert %[[A]] : tensor<?xf32> to tensor<?xf32, #sparse_tensor.encoding<{{{.*}}}>
// CHECK: return %[[C]] : tensor<?xf32, #sparse_tensor.encoding<{{{.*}}}>>
func.func @sparse_repair_cast(%a : tensor<?xf32>) -> tensor<?xf32, #SparseVector> {
%0 = tensor.cast %a : tensor<?xf32> to tensor<?xf32, #SparseVector>
return %0 : tensor<?xf32, #SparseVector>
}
// CHECK-LABEL: func @sparse_fuse_slice(
// CHECK-SAME: %[[A:.*]]: tensor<2x3xi64, #sparse_tensor.encoding<{{{.*}}}>>)
// CHECK: %[[E:.*]] = tensor.extract_slice %[[A]][1, 0] [1, 3] [1, 1] : tensor<2x3xi64, #sparse_tensor.encoding<{{{.*}}}>> to tensor<1x3xi64, #sparse_tensor.encoding<{{{.*}}}>>
// CHECK: %[[C:.*]] = sparse_tensor.convert %[[E]] : tensor<1x3xi64, #sparse_tensor.encoding<{{{.*}}}>> to tensor<1x3xi64, #sparse_tensor.encoding<{{{.*}}}>>
// CHECK: return %[[C]] : tensor<1x3xi64, #sparse_tensor.encoding<{{{.*}}}>>
func.func @sparse_fuse_slice(%a : tensor<2x3xi64, #SortedCOO>) -> tensor<1x3xi64, #SortedCOO> {
%extracted_slice = tensor.extract_slice %a[1, 0] [1, 3] [1, 1] : tensor<2x3xi64, #SortedCOO> to tensor<1x3xi64>
%cast = tensor.cast %extracted_slice : tensor<1x3xi64> to tensor<1x3xi64, #Slice>
%0 = sparse_tensor.convert %cast : tensor<1x3xi64, #Slice> to tensor<1x3xi64, #SortedCOO>
return %0 : tensor<1x3xi64, #SortedCOO>
}