From dbdb4affa00785c31675be0535b4dc89136b8502 Mon Sep 17 00:00:00 2001 From: Aart Bik Date: Fri, 21 Apr 2023 14:40:31 -0700 Subject: [PATCH] [mlir][sparse] avoid slice rewriting when conditions are not met Reviewed By: Peiming Differential Revision: https://reviews.llvm.org/D148964 --- .../Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp index 4e1e66d8bc0f..c1cb0926622f 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp @@ -1165,11 +1165,9 @@ public: MLIRContext *ctx = op.getContext(); auto srcEnc = getSparseTensorEncoding(op.getSourceType()); auto dstEnc = getSparseTensorEncoding(op.getResult().getType()); - if (!srcEnc && !dstEnc) - return failure(); - // TODO: We should check these in ExtractSliceOp::verify. - assert(srcEnc && dstEnc && dstEnc.isSlice()); + if (!srcEnc || !dstEnc || !dstEnc.isSlice()) + return failure(); assert(srcEnc.getDimLevelType() == dstEnc.getDimLevelType()); assert(srcEnc.getDimOrdering() == dstEnc.getDimOrdering()); assert(srcEnc.getHigherOrdering() == dstEnc.getHigherOrdering());