[MLIR] Enable pattern only for scf.forall op (#110230)

The init args shape might change in the loop body and hence the pattern
doesn't hold true.
This commit is contained in:
Prashant Kumar
2024-10-17 18:32:03 +05:30
committed by GitHub
parent d9cd607200
commit c1047ba836

View File

@@ -18,6 +18,7 @@
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -131,11 +132,25 @@ struct IterArgsToInitArgs : public OpRewritePattern<tensor::DimOp> {
auto blockArg = dyn_cast<BlockArgument>(dimOp.getSource());
if (!blockArg)
return failure();
auto loopLikeOp =
dyn_cast<LoopLikeOpInterface>(blockArg.getParentBlock()->getParentOp());
if (!loopLikeOp)
// TODO: Enable this for loopLikeInterface. Restricting for scf.for
// because the init args shape might change in the loop body.
// For e.g.:
// ```
// %0 = tensor.empty(%c1) : tensor<?xf32>
// %r = scf.for %iv = %c0 to %c10 step %c1 iter_args(%arg0 = %0) ->
// tensor<?xf32> {
// %1 = tensor.dim %arg0, %c0 : tensor<?xf32>
// %2 = arith.addi %c1, %1 : index
// %3 = tensor.empty(%2) : tensor<?xf32>
// scf.yield %3 : tensor<?xf32>
// }
//
// ```
auto forAllOp =
dyn_cast<scf::ForallOp>(blockArg.getParentBlock()->getParentOp());
if (!forAllOp)
return failure();
Value initArg = loopLikeOp.getTiedLoopInit(blockArg)->get();
Value initArg = forAllOp.getTiedLoopInit(blockArg)->get();
rewriter.modifyOpInPlace(
dimOp, [&]() { dimOp.getSourceMutable().assign(initArg); });
return success();