[mlir][sparse] relax constraints on tensor.cast with pre-rewriting

Reviewed By: wrengr

Differential Revision: https://reviews.llvm.org/D149489
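In short (a hedged sketch; the new pre_rewriting.mlir test at the bottom exercises all three handled cases): a tensor.cast whose operand and result types differ only in sparsity encoding is no longer left behind by prior rewriting, but is folded away, fused into its producer, or repaired into sparse_tensor.convert:

// Sketch only: what -pre-sparsification-rewrite does to a dense-to-sparse cast.
#SV = #sparse_tensor.encoding<{ dimLevelType = ["compressed"] }>

func.func @example(%a : tensor<?xf32>) -> tensor<?xf32, #SV> {
  // Before: %0 = tensor.cast %a : tensor<?xf32> to tensor<?xf32, #SV>
  // After:
  %0 = sparse_tensor.convert %a : tensor<?xf32> to tensor<?xf32, #SV>
  return %0 : tensor<?xf32, #SV>
}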
mlir/include/mlir/Dialect/Tensor/IR/Tensor.h

@@ -142,6 +142,9 @@ FailureOr<Value> getOrCreateDestination(OpBuilder &b, Location loc,
 LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op,
                                       SmallVector<Value> &result);
 
+/// Tests if types are the same when ignoring encoding on ranked tensors.
+bool isSameTypeWithoutEncoding(Type tp1, Type tp2);
+
 /// Function to control the folding of constant and extract slice.
 using ControlConstantExtractSliceFusionFn = std::function<bool(ExtractSliceOp)>;
 
mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp

@@ -449,8 +449,6 @@ mlir::sparse_tensor::getSparseTensorEncoding(Type type) {
   return nullptr;
 }
 
-/// Returns true iff the given sparse tensor encoding attribute has a trailing
-/// COO region starting at the given level.
 bool mlir::sparse_tensor::isCOOType(SparseTensorEncodingAttr enc,
                                     Level startLvl, bool isUnique) {
   if (!enc ||
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp

@@ -346,6 +346,45 @@ private:
   }
 };
 
+// Fuse a tensor cast into its producing operation. Note that a tensor.cast
+// should really not be used to convert between sparse encodings. Since
+// the pattern currently appears as a result of some prior rewriting,
+// we make an attempt to repair very obvious cases.
+// TODO: audit the pure tensor dialect rewriting rules
+struct FuseTensorCast : public OpRewritePattern<tensor::CastOp> {
+public:
+  using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::CastOp op,
+                                PatternRewriter &rewriter) const override {
+    Type srcType = op.getSource().getType();
+    Type dstType = op.getDest().getType();
+    // A nop cast simply folds away.
+    if (srcType == dstType) {
+      rewriter.replaceOp(op, op->getResults());
+      return success();
+    }
+    // See if a sparsity changing cast can be fused into producer.
+    if (tensor::isSameTypeWithoutEncoding(srcType, dstType)) {
+      if (Operation *def = op.getSource().getDefiningOp()) {
+        if (def->hasOneUse() && isa<tensor::ExtractSliceOp>(def)) {
+          def->getResult(0).setType(op->getResultTypes()[0]);
+          rewriter.replaceOp(op, def->getResult(0));
+          return success();
+        }
+      }
+    }
+    // Repair tensor casts with at least one sparse operand into the
+    // properly supported sparse_tensor.convert.
+    if (getSparseTensorEncoding(srcType) || getSparseTensorEncoding(dstType)) {
+      rewriter.replaceOpWithNewOp<ConvertOp>(op, dstType, op.getSource());
+      return success();
+    }
+    // Fail otherwise.
+    return failure();
+  }
+};
+
 /// Sparse rewriting rule for sparse-to-sparse reshape operator.
 template <typename ReshapeOp>
 struct Sparse2SparseReshapeRewriter : public OpRewritePattern<ReshapeOp> {
@@ -1125,7 +1164,7 @@ struct OutRewriter : public OpRewritePattern<OutOp> {
 //===---------------------------------------------------------------------===//
 
 void mlir::populatePreSparsificationRewriting(RewritePatternSet &patterns) {
-  patterns.add<FoldInvariantYield, FuseSparseMultiplyOverAdd>(
+  patterns.add<FoldInvariantYield, FuseSparseMultiplyOverAdd, FuseTensorCast>(
       patterns.getContext());
 }
 
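For downstream users nothing changes in how the patterns are picked up; populatePreSparsificationRewriting simply registers one more pattern. A minimal sketch of a driver (assumes an MLIR build; runPreSparsificationRewriting itself is hypothetical, only the two MLIR entry points are real):

#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Applies FoldInvariantYield, FuseSparseMultiplyOverAdd, and the new
// FuseTensorCast greedily until fixpoint.
void runPreSparsificationRewriting(mlir::ModuleOp module) {
  mlir::RewritePatternSet patterns(module.getContext());
  mlir::populatePreSparsificationRewriting(patterns);
  if (mlir::failed(
          mlir::applyPatternsAndFoldGreedily(module, std::move(patterns))))
    module.emitError("pre-sparsification rewriting did not converge");
}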
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

@@ -110,6 +110,16 @@ LogicalResult tensor::getOrCreateDestinations(OpBuilder &b, Location loc,
   return success();
 }
 
+bool tensor::isSameTypeWithoutEncoding(Type tp1, Type tp2) {
+  if (auto rtp1 = tp1.dyn_cast<RankedTensorType>()) {
+    if (auto rtp2 = tp2.dyn_cast<RankedTensorType>())
+      return rtp1.getShape() == rtp2.getShape() &&
+             rtp1.getElementType() == rtp2.getElementType();
+    return false;
+  }
+  return tp1 == tp2; // default implementation
+}
+
 /// Compute the dropped dimensions of a rank-reducing tensor.extract_slice op or
 /// rank-extending tensor.insert_slice op.
 static llvm::SmallBitVector getDroppedDims(ArrayRef<int64_t> reducedShape,
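The helper's semantics in isolation, as a hedged standalone sketch (assumes an MLIR development build; mlir::UnitAttr is just a stand-in for a real #sparse_tensor.encoding, since the comparison ignores any encoding attribute):

#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"

int main() {
  mlir::MLIRContext ctx;
  auto f32 = mlir::FloatType::getF32(&ctx);
  // tensor<?xf32> without encoding vs. tensor<?xf32, unit> (stand-in encoding).
  auto plain = mlir::RankedTensorType::get({mlir::ShapedType::kDynamic}, f32);
  auto encoded = mlir::RankedTensorType::get({mlir::ShapedType::kDynamic}, f32,
                                             mlir::UnitAttr::get(&ctx));
  // Same shape and element type, so the encodings are ignored: true.
  bool same = mlir::tensor::isSameTypeWithoutEncoding(plain, encoded);
  // Different shape: false.
  bool diff = mlir::tensor::isSameTypeWithoutEncoding(
      plain, mlir::RankedTensorType::get({3}, f32));
  return (same && !diff) ? 0 : 1;
}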
@@ -1343,18 +1353,6 @@ void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
         getReassociationIndicesAttribute(b, reassociation));
 }
 
-// Checks if types are the same, but ignoring encoding on ranked tensors.
-static bool isSameTypesWithoutEncoding(Type tp1, Type tp2) {
-  if (auto rtp1 = tp1.dyn_cast<RankedTensorType>()) {
-    if (auto rtp2 = tp2.dyn_cast<RankedTensorType>())
-      return rtp1.getShape() == rtp2.getShape() &&
-             rtp1.getElementType() == rtp2.getElementType();
-    return false;
-  }
-  // Default implementation.
-  return tp1 == tp2;
-}
-
 template <typename TensorReshapeOp, bool isExpansion = std::is_same<
                                         TensorReshapeOp, ExpandShapeOp>::value>
 static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op,

@@ -1367,7 +1365,7 @@ static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op,
   auto maps = op.getReassociationMaps();
   RankedTensorType expectedType =
       CollapseShapeOp::inferCollapsedType(expandedType, maps);
-  if (!isSameTypesWithoutEncoding(collapsedType, expectedType))
+  if (!isSameTypeWithoutEncoding(collapsedType, expectedType))
     return op.emitOpError("expected collapsed type to be ")
            << expectedType << ", but got " << collapsedType;
   return success();
mlir/test/Dialect/SparseTensor/rewriting.mlir → mlir/test/Dialect/SparseTensor/post_rewriting.mlir (renamed; Executable file → Normal file; 0 lines changed)

mlir/test/Dialect/SparseTensor/pre_rewriting.mlir (new file; 45 lines)
@@ -0,0 +1,45 @@
+// RUN: mlir-opt %s -pre-sparsification-rewrite | FileCheck %s
+
+#SparseVector = #sparse_tensor.encoding<{
+  dimLevelType = ["compressed"]
+}>
+
+#SortedCOO = #sparse_tensor.encoding<{
+  dimLevelType = [ "compressed-nu", "singleton" ]
+}>
+
+#Slice = #sparse_tensor.encoding<{
+  dimLevelType = [ "compressed-nu", "singleton" ],
+  slice = [ (?, 1, 1), (?, 3, 1) ]
+}>
+
+// CHECK-LABEL: func @sparse_nop_cast(
+// CHECK-SAME: %[[A:.*]]: tensor<?xf32, #sparse_tensor.encoding<{{{.*}}}>>)
+// CHECK: return %[[A]] : tensor<?xf32, #sparse_tensor.encoding<{{{.*}}}>>
+func.func @sparse_nop_cast(%a : tensor<?xf32, #SparseVector>) -> tensor<?xf32, #SparseVector> {
+  %0 = tensor.cast %a : tensor<?xf32, #SparseVector> to tensor<?xf32, #SparseVector>
+  %1 = tensor.cast %0 : tensor<?xf32, #SparseVector> to tensor<?xf32, #SparseVector>
+  %2 = tensor.cast %1 : tensor<?xf32, #SparseVector> to tensor<?xf32, #SparseVector>
+  return %2 : tensor<?xf32, #SparseVector>
+}
+
+// CHECK-LABEL: func @sparse_repair_cast(
+// CHECK-SAME: %[[A:.*]]: tensor<?xf32>)
+// CHECK: %[[C:.*]] = sparse_tensor.convert %[[A]] : tensor<?xf32> to tensor<?xf32, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK: return %[[C]] : tensor<?xf32, #sparse_tensor.encoding<{{{.*}}}>>
+func.func @sparse_repair_cast(%a : tensor<?xf32>) -> tensor<?xf32, #SparseVector> {
+  %0 = tensor.cast %a : tensor<?xf32> to tensor<?xf32, #SparseVector>
+  return %0 : tensor<?xf32, #SparseVector>
+}
+
+// CHECK-LABEL: func @sparse_fuse_slice(
+// CHECK-SAME: %[[A:.*]]: tensor<2x3xi64, #sparse_tensor.encoding<{{{.*}}}>>)
+// CHECK: %[[E:.*]] = tensor.extract_slice %[[A]][1, 0] [1, 3] [1, 1] : tensor<2x3xi64, #sparse_tensor.encoding<{{{.*}}}>> to tensor<1x3xi64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK: %[[C:.*]] = sparse_tensor.convert %[[E]] : tensor<1x3xi64, #sparse_tensor.encoding<{{{.*}}}>> to tensor<1x3xi64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK: return %[[C]] : tensor<1x3xi64, #sparse_tensor.encoding<{{{.*}}}>>
+func.func @sparse_fuse_slice(%a : tensor<2x3xi64, #SortedCOO>) -> tensor<1x3xi64, #SortedCOO> {
+  %extracted_slice = tensor.extract_slice %a[1, 0] [1, 3] [1, 1] : tensor<2x3xi64, #SortedCOO> to tensor<1x3xi64>
+  %cast = tensor.cast %extracted_slice : tensor<1x3xi64> to tensor<1x3xi64, #Slice>
+  %0 = sparse_tensor.convert %cast : tensor<1x3xi64, #Slice> to tensor<1x3xi64, #SortedCOO>
+  return %0 : tensor<1x3xi64, #SortedCOO>
+}