[MLIR] Fix incorrect memref::DimOp canonicalization, add tensor::DimOp canonicalization (#84225)
The current canonicalization of `memref.dim` on the result of `memref.reshape` into a `memref.load` of the shape operand is incorrect: it does not check whether the `index` operand of `memref.dim` dominates the source `memref.reshape` op. The pattern always inserts the `memref.load` right after `memref.reshape` (to ensure the shape `memref` is not mutated before the load), so the load can end up using an index value that is only defined later. As a result, the following error is observed:
```
$> mlir-opt --canonicalize input.mlir
func.func @reshape_dim(%arg0: memref<*xf32>, %arg1: memref<?xindex>, %arg2: index) -> index {
%c4 = arith.constant 4 : index
%reshape = memref.reshape %arg0(%arg1) : (memref<*xf32>, memref<?xindex>) -> memref<*xf32>
%0 = arith.muli %arg2, %c4 : index
%dim = memref.dim %reshape, %0 : memref<*xf32>
return %dim : index
}
```
results in:
```
dominator.mlir:22:12: error: operand #1 does not dominate this use
%dim = memref.dim %reshape, %0 : memref<*xf32>
^
dominator.mlir:22:12: note: see current operation: %1 = "memref.load"(%arg1, %2) <{nontemporal = false}> : (memref<?xindex>, index) -> index
dominator.mlir:21:10: note: operand defined here (op in the same block)
%0 = arith.muli %arg2, %c4 : index
```
Properly fixing this issue would require dominance analysis, which is too expensive to run inside a canonicalization pattern. Instead, this patch makes the pattern more strict/conservative about the legality condition under which the canonicalization is performed: the fold is only applied when a cheap structural check shows that the index dominates the reshape.
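For illustration, here is a minimal sketch of a case the stricter check still folds: the dim index is a block argument of the enclosing function, so it trivially dominates the reshape. Function and value names are made up for the example; it mirrors the `@dim_of_memref_reshape_block_arg_index` test added in this patch.
```
// Before canonicalization: %idx is a function argument, so it dominates the reshape.
func.func @fold_ok(%src: memref<*xf32>, %shape: memref<?xindex>, %idx: index) -> index {
  %reshape = memref.reshape %src(%shape) : (memref<*xf32>, memref<?xindex>) -> memref<*xf32>
  %dim = memref.dim %reshape, %idx : memref<*xf32>
  return %dim : index
}

// After `mlir-opt --canonicalize`, the dim is folded into a load of the shape operand,
// placed right after the reshape:
//   %dim = memref.load %shape[%idx] : memref<?xindex>
//   return %dim : index
```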
The more general pattern is also added for `tensor.dim`. Since tensors are immutable, we do not need to worry about where to insert the `tensor.extract` call after canonicalization.
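As a sketch of the tensor-side fold (value names here are hypothetical; it mirrors the `@dim_of_reshape` tests added below), `tensor.dim` of a `tensor.reshape` result becomes a `tensor.extract` from the shape operand, inserted directly at the `tensor.dim`:
```
// Before canonicalization.
func.func @tensor_fold(%src: tensor<*xf32>, %shape: tensor<?xindex>) -> index {
  %c3 = arith.constant 3 : index
  %reshaped = tensor.reshape %src(%shape) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
  %dim = tensor.dim %reshaped, %c3 : tensor<*xf32>
  return %dim : index
}

// After `mlir-opt --canonicalize`, the dimension is read directly from the shape tensor:
//   %dim = tensor.extract %shape[%c3] : tensor<?xindex>
//   return %dim : index
```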
```
@@ -1080,7 +1080,37 @@ struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
     auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
 
     if (!reshape)
-      return failure();
+      return rewriter.notifyMatchFailure(
+          dim, "Dim op is not defined by a reshape op.");
+
+    // dim of a memref reshape can be folded if dim.getIndex() dominates the
+    // reshape. Instead of using `DominanceInfo` (which is usually costly) we
+    // cheaply check that either of the following conditions hold:
+    //      1. dim.getIndex() is defined in the same block as reshape but before
+    //      reshape.
+    //      2. dim.getIndex() is defined in a parent block of
+    //      reshape.
+
+    // Check condition 1
+    if (dim.getIndex().getParentBlock() == reshape->getBlock()) {
+      if (auto *definingOp = dim.getIndex().getDefiningOp()) {
+        if (reshape->isBeforeInBlock(definingOp)) {
+          return rewriter.notifyMatchFailure(
+              dim,
+              "dim.getIndex is not defined before reshape in the same block.");
+        }
+      } // else dim.getIndex is a block argument to reshape->getBlock and
+        // dominates reshape
+    }   // Check condition 2
+    else if (dim->getBlock() != reshape->getBlock() &&
+             !dim.getIndex().getParentRegion()->isProperAncestor(
+                 reshape->getParentRegion())) {
+      // If dim and reshape are in the same block but dim.getIndex() isn't, we
+      // already know dim.getIndex() dominates reshape without calling
+      // `isProperAncestor`
+      return rewriter.notifyMatchFailure(
+          dim, "dim.getIndex does not dominate reshape.");
+    }
 
     // Place the load directly after the reshape to ensure that the shape memref
     // was not mutated.
```
```
@@ -824,11 +824,37 @@ struct DimOfDestStyleOp : public OpRewritePattern<DimOp> {
     return success();
   }
 };
+
+/// Fold dim of a tensor reshape operation to a extract into the reshape's shape
+/// operand.
+struct DimOfReshapeOp : public OpRewritePattern<DimOp> {
+  using OpRewritePattern<DimOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(DimOp dim,
+                                PatternRewriter &rewriter) const override {
+    auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
+
+    if (!reshape)
+      return failure();
+
+    // Since tensors are immutable we don't need to worry about where to place
+    // the extract call
+    rewriter.setInsertionPointAfter(dim);
+    Location loc = dim.getLoc();
+    Value extract =
+        rewriter.create<ExtractOp>(loc, reshape.getShape(), dim.getIndex());
+    if (extract.getType() != dim.getType())
+      extract =
+          rewriter.create<arith::IndexCastOp>(loc, dim.getType(), extract);
+    rewriter.replaceOp(dim, extract);
+    return success();
+  }
+};
 } // namespace
 
 void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
-  results.add<DimOfCastOp, DimOfDestStyleOp>(context);
+  results.add<DimOfCastOp, DimOfDestStyleOp, DimOfReshapeOp>(context);
 }
 
 //===----------------------------------------------------------------------===//
```
```
@@ -313,6 +313,59 @@ func.func @dim_of_memref_reshape_i32(%arg0: memref<*xf32>, %arg1: memref<?xi32>)
 
 // -----
 
+// Test case: memref.dim(memref.reshape %v %shp, %idx) -> memref.load %shp[%idx]
+// CHECK-LABEL: func @dim_of_memref_reshape_block_arg_index(
+// CHECK-SAME: %[[MEM:[0-9a-z]+]]: memref<*xf32>,
+// CHECK-SAME: %[[SHP:[0-9a-z]+]]: memref<?xindex>,
+// CHECK-SAME: %[[IDX:[0-9a-z]+]]: index
+// CHECK-NEXT: %[[DIM:.*]] = memref.load %[[SHP]][%[[IDX]]]
+// CHECK-NOT: memref.dim
+// CHECK: return %[[DIM]] : index
+func.func @dim_of_memref_reshape_block_arg_index(%arg0: memref<*xf32>, %arg1: memref<?xindex>, %arg2: index) -> index {
+  %reshape = memref.reshape %arg0(%arg1) : (memref<*xf32>, memref<?xindex>) -> memref<*xf32>
+  %dim = memref.dim %reshape, %arg2 : memref<*xf32>
+  return %dim : index
+}
+
+// -----
+
+// Test case: memref.dim(memref.reshape %v %shp, %idx) is not folded into memref.load %shp[%idx]
+// CHECK-LABEL: func @dim_of_memref_reshape_for(
+// CHECK: memref.reshape
+// CHECK: memref.dim
+// CHECK-NOT: memref.load
+func.func @dim_of_memref_reshape_for( %arg0: memref<*xf32>, %arg1: memref<?xindex>) -> index {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+
+  %0 = memref.reshape %arg0(%arg1) : (memref<*xf32>, memref<?xindex>) -> memref<*xf32>
+
+  %1 = scf.for %arg2 = %c0 to %c4 step %c1 iter_args(%arg3 = %c1) -> (index) {
+    %2 = memref.dim %0, %arg2 : memref<*xf32>
+    %3 = arith.muli %arg3, %2 : index
+    scf.yield %3 : index
+  }
+  return %1 : index
+}
+
+// -----
+
+// Test case: memref.dim(memref.reshape %v %shp, %idx) is not folded into memref.load %shp[%idx]
+// CHECK-LABEL: func @dim_of_memref_reshape_undominated(
+// CHECK: memref.reshape
+// CHECK: memref.dim
+// CHECK-NOT: memref.load
+func.func @dim_of_memref_reshape_undominated(%arg0: memref<*xf32>, %arg1: memref<?xindex>, %arg2: index) -> index {
+  %c4 = arith.constant 4 : index
+  %reshape = memref.reshape %arg0(%arg1) : (memref<*xf32>, memref<?xindex>) -> memref<*xf32>
+  %0 = arith.muli %arg2, %c4 : index
+  %dim = memref.dim %reshape, %0 : memref<*xf32>
+  return %dim : index
+}
+
+// -----
+
 // CHECK-LABEL: func @alloc_const_fold
 func.func @alloc_const_fold() -> memref<?xf32> {
   // CHECK-NEXT: memref.alloc() : memref<4xf32>
```
```
@@ -2287,3 +2287,83 @@ func.func @infer_and_fold_pack_unpack_same_tiles(%t: tensor<10x20x4x4xf32>) -> t
 // CHECK-LABEL: func.func @infer_and_fold_pack_unpack_same_tiles
 // CHECK-SAME: %[[SRC:[0-9a-zA-Z]+]]
 // CHECK: return %[[SRC]]
+
+// -----
+
+// Test case: Folding of tensor.dim(tensor.reshape %v %shp, %idx) -> tensor.extract %shp[%idx]
+// CHECK-LABEL: func @dim_of_reshape(
+// CHECK-SAME: %[[MEM:[0-9a-z]+]]: tensor<*xf32>,
+// CHECK-SAME: %[[SHP:[0-9a-z]+]]: tensor<?xindex>
+// CHECK-NEXT: %[[IDX:.*]] = arith.constant 3
+// CHECK-NEXT: %[[DIM:.*]] = tensor.extract %[[SHP]][%[[IDX]]]
+// CHECK-NOT: tensor.store
+// CHECK-NOT: tensor.dim
+// CHECK-NOT: tensor.reshape
+// CHECK: return %[[DIM]] : index
+func.func @dim_of_reshape(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>)
+    -> index {
+  %c3 = arith.constant 3 : index
+  %0 = tensor.reshape %arg0(%arg1)
+      : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+  // Update the shape to test that the load ends up in the right place.
+  tensor.insert %c3 into %arg1[%c3] : tensor<?xindex>
+  %1 = tensor.dim %0, %c3 : tensor<*xf32>
+  return %1 : index
+}
+
+// -----
+
+// Test case: Folding of tensor.dim(tensor.reshape %v %shp, %idx) -> tensor.extract %shp[%idx]
+// CHECK-LABEL: func @dim_of_reshape_i32(
+// CHECK: tensor.extract
+// CHECK-NEXT: %[[CAST:.*]] = arith.index_cast
+// CHECK-NOT: tensor.dim
+// CHECK-NOT: tensor.reshape
+// CHECK: return %[[CAST]] : index
+func.func @dim_of_reshape_i32(%arg0: tensor<*xf32>, %arg1: tensor<?xi32>)
+    -> index {
+  %c3 = arith.constant 3 : index
+  %0 = tensor.reshape %arg0(%arg1)
+      : (tensor<*xf32>, tensor<?xi32>) -> tensor<*xf32>
+  %1 = tensor.dim %0, %c3 : tensor<*xf32>
+  return %1 : index
+}
+
+// -----
+
+// Test case: tensor.dim(tensor.reshape %v %shp, %idx) is folded into tensor.extract %shp[%idx]
+// CHECK-LABEL: func @dim_of_reshape_for(
+// CHECK: scf.for
+// CHECK-NEXT: tensor.extract
+// CHECK-NOT: tensor.dim
+// CHECK-NOT: tensor.reshape
+func.func @dim_of_reshape_for( %arg0: tensor<*xf32>, %arg1: tensor<?xindex>) -> index {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+
+  %0 = tensor.reshape %arg0(%arg1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+
+  %1 = scf.for %arg2 = %c0 to %c4 step %c1 iter_args(%arg3 = %c1) -> (index) {
+    %2 = tensor.dim %0, %arg2 : tensor<*xf32>
+    %3 = arith.muli %arg3, %2 : index
+    scf.yield %3 : index
+  }
+  return %1 : index
+}
+
+// -----
+
+// Test case: tensor.dim(tensor.reshape %v %shp, %idx) is folded into tensor.extract %shp[%idx]
+// CHECK-LABEL: func @dim_of_reshape_undominated(
+// CHECK: arith.muli
+// CHECK-NEXT: tensor.extract
+// CHECK-NOT: tensor.dim
+// CHECK-NOT: tensor.reshape
+func.func @dim_of_reshape_undominated(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>, %arg2: index) -> index {
+  %c4 = arith.constant 4 : index
+  %reshape = tensor.reshape %arg0(%arg1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+  %0 = arith.muli %arg2, %c4 : index
+  %dim = tensor.dim %reshape, %0 : tensor<*xf32>
+  return %dim : index
+}
```