mirror of
https://github.com/intel/llvm.git
synced 2026-01-20 10:58:11 +08:00
[MLIR] Enable pattern only for scf.forall op (#110230)
The init args shape might change in the loop body and hence the pattern doesn't hold true.
This commit is contained in:
@@ -18,6 +18,7 @@
|
||||
#include "mlir/Dialect/Arith/Utils/Utils.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Interfaces/InferTypeOpInterface.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
@@ -131,11 +132,25 @@ struct IterArgsToInitArgs : public OpRewritePattern<tensor::DimOp> {
|
||||
auto blockArg = dyn_cast<BlockArgument>(dimOp.getSource());
|
||||
if (!blockArg)
|
||||
return failure();
|
||||
auto loopLikeOp =
|
||||
dyn_cast<LoopLikeOpInterface>(blockArg.getParentBlock()->getParentOp());
|
||||
if (!loopLikeOp)
|
||||
// TODO: Enable this for loopLikeInterface. Restricting for scf.for
|
||||
// because the init args shape might change in the loop body.
|
||||
// For e.g.:
|
||||
// ```
|
||||
// %0 = tensor.empty(%c1) : tensor<?xf32>
|
||||
// %r = scf.for %iv = %c0 to %c10 step %c1 iter_args(%arg0 = %0) ->
|
||||
// tensor<?xf32> {
|
||||
// %1 = tensor.dim %arg0, %c0 : tensor<?xf32>
|
||||
// %2 = arith.addi %c1, %1 : index
|
||||
// %3 = tensor.empty(%2) : tensor<?xf32>
|
||||
// scf.yield %3 : tensor<?xf32>
|
||||
// }
|
||||
//
|
||||
// ```
|
||||
auto forAllOp =
|
||||
dyn_cast<scf::ForallOp>(blockArg.getParentBlock()->getParentOp());
|
||||
if (!forAllOp)
|
||||
return failure();
|
||||
Value initArg = loopLikeOp.getTiedLoopInit(blockArg)->get();
|
||||
Value initArg = forAllOp.getTiedLoopInit(blockArg)->get();
|
||||
rewriter.modifyOpInPlace(
|
||||
dimOp, [&]() { dimOp.getSourceMutable().assign(initArg); });
|
||||
return success();
|
||||
|
||||
Reference in New Issue
Block a user