[mlir][affine] implement promoteIfSingleIteration for AffineForOp (#72547)

This commit is contained in:
Maksim Levental
2023-11-19 12:27:41 -06:00
committed by GitHub
parent 99387e33dc
commit aa6be2f7c9
11 changed files with 190 additions and 185 deletions

View File

@@ -15,6 +15,7 @@
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/ArrayRef.h"
#include <optional>
namespace mlir {
@@ -29,20 +30,6 @@ namespace affine {
class AffineForOp;
class NestedPattern;
/// Returns the trip count of the loop as an affine map with its corresponding
/// operands if the latter is expressible as an affine expression, and nullptr
/// otherwise. This method always succeeds as long as the lower bound is not a
/// multi-result map. The trip count expression is simplified before returning.
/// This method only utilizes map composition to construct lower and upper
/// bounds before computing the trip count expressions
void getTripCountMapAndOperands(AffineForOp forOp, AffineMap *map,
SmallVectorImpl<Value> *operands);
/// Returns the trip count of the loop if it's a constant, std::nullopt
/// otherwise. This uses affine expression analysis and is able to determine
/// constant trip count in non-trivial cases.
std::optional<uint64_t> getConstantTripCount(AffineForOp forOp);
/// Returns the greatest known integral divisor of the trip count. Affine
/// expression analysis is used (indirectly through getTripCount), and
/// this method is thus able to determine non-trivial divisors.

View File

@@ -117,7 +117,8 @@ public:
/// Returns the affine map used to access the source memref.
AffineMap getSrcMap() { return getSrcMapAttr().getValue(); }
AffineMapAttr getSrcMapAttr() {
return cast<AffineMapAttr>(*(*this)->getInherentAttr(getSrcMapAttrStrName()));
return cast<AffineMapAttr>(
*(*this)->getInherentAttr(getSrcMapAttrStrName()));
}
/// Returns the source memref affine map indices for this DMA operation.
@@ -156,7 +157,8 @@ public:
/// Returns the affine map used to access the destination memref.
AffineMap getDstMap() { return getDstMapAttr().getValue(); }
AffineMapAttr getDstMapAttr() {
return cast<AffineMapAttr>(*(*this)->getInherentAttr(getDstMapAttrStrName()));
return cast<AffineMapAttr>(
*(*this)->getInherentAttr(getDstMapAttrStrName()));
}
/// Returns the destination memref indices for this DMA operation.
@@ -185,7 +187,8 @@ public:
/// Returns the affine map used to access the tag memref.
AffineMap getTagMap() { return getTagMapAttr().getValue(); }
AffineMapAttr getTagMapAttr() {
return cast<AffineMapAttr>(*(*this)->getInherentAttr(getTagMapAttrStrName()));
return cast<AffineMapAttr>(
*(*this)->getInherentAttr(getTagMapAttrStrName()));
}
/// Returns the tag memref indices for this DMA operation.
@@ -307,7 +310,8 @@ public:
/// Returns the affine map used to access the tag memref.
AffineMap getTagMap() { return getTagMapAttr().getValue(); }
AffineMapAttr getTagMapAttr() {
return cast<AffineMapAttr>(*(*this)->getInherentAttr(getTagMapAttrStrName()));
return cast<AffineMapAttr>(
*(*this)->getInherentAttr(getTagMapAttrStrName()));
}
/// Returns the tag memref index for this DMA operation.
@@ -465,6 +469,23 @@ AffineForOp getForInductionVarOwner(Value val);
/// AffineParallelOp.
AffineParallelOp getAffineParallelInductionVarOwner(Value val);
/// Helper to replace uses of loop carried values (iter_args) and loop
/// yield values while promoting single iteration affine.for ops.
void replaceIterArgsAndYieldResults(AffineForOp forOp);
/// Returns the trip count of the loop as an affine expression if the latter is
/// expressible as an affine expression, and nullptr otherwise. The trip count
/// expression is simplified before returning. This method only utilizes map
/// composition to construct lower and upper bounds before computing the trip
/// count expressions.
void getTripCountMapAndOperands(AffineForOp forOp, AffineMap *tripCountMap,
SmallVectorImpl<Value> *tripCountOperands);
/// Returns the trip count of the loop if it's a constant, std::nullopt
/// otherwise. This uses affine expression analysis and is able to determine
/// constant trip count in non-trivial cases.
std::optional<uint64_t> getConstantTripCount(AffineForOp forOp);
/// Extracts the induction variables from a list of AffineForOps and places them
/// in the output argument `ivs`.
void extractForInductionVars(ArrayRef<AffineForOp> forInsts,

View File

@@ -121,7 +121,7 @@ def AffineForOp : Affine_Op<"for",
ImplicitAffineTerminator, ConditionallySpeculatable,
RecursiveMemoryEffects, DeclareOpInterfaceMethods<LoopLikeOpInterface,
["getSingleInductionVar", "getSingleLowerBound", "getSingleStep",
"getSingleUpperBound", "getYieldedValuesMutable",
"getSingleUpperBound", "getYieldedValuesMutable", "promoteIfSingleIteration",
"replaceWithAdditionalYields"]>,
DeclareOpInterfaceMethods<RegionBranchOpInterface,
["getEntrySuccessorOperands"]>]> {

View File

@@ -83,10 +83,6 @@ LogicalResult loopUnrollJamByFactor(AffineForOp forOp,
LogicalResult loopUnrollJamUpToFactor(AffineForOp forOp,
uint64_t unrollJamFactor);
/// Promotes the loop body of a AffineForOp to its containing block if the loop
/// was known to have a single iteration.
LogicalResult promoteIfSingleIteration(AffineForOp forOp);
/// Promotes all single iteration AffineForOp's in the Function, i.e., moves
/// their body into the containing Block.
void promoteSingleIterationLoops(func::FuncOp f);

View File

@@ -12,7 +12,6 @@
#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
#include "mlir/Dialect/Affine/Analysis/NestedMatcher.h"
@@ -20,9 +19,9 @@
#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
#include "mlir/Support/MathExtras.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallString.h"
#include <numeric>
#include <optional>
#include <type_traits>
@@ -30,83 +29,6 @@
using namespace mlir;
using namespace mlir::affine;
/// Returns the trip count of the loop as an affine expression if the latter is
/// expressible as an affine expression, and nullptr otherwise. The trip count
/// expression is simplified before returning. This method only utilizes map
/// composition to construct lower and upper bounds before computing the trip
/// count expressions.
void mlir::affine::getTripCountMapAndOperands(
AffineForOp forOp, AffineMap *tripCountMap,
SmallVectorImpl<Value> *tripCountOperands) {
MLIRContext *context = forOp.getContext();
int64_t step = forOp.getStepAsInt();
int64_t loopSpan;
if (forOp.hasConstantBounds()) {
int64_t lb = forOp.getConstantLowerBound();
int64_t ub = forOp.getConstantUpperBound();
loopSpan = ub - lb;
if (loopSpan < 0)
loopSpan = 0;
*tripCountMap = AffineMap::getConstantMap(ceilDiv(loopSpan, step), context);
tripCountOperands->clear();
return;
}
auto lbMap = forOp.getLowerBoundMap();
auto ubMap = forOp.getUpperBoundMap();
if (lbMap.getNumResults() != 1) {
*tripCountMap = AffineMap();
return;
}
// Difference of each upper bound expression from the single lower bound
// expression (divided by the step) provides the expressions for the trip
// count map.
AffineValueMap ubValueMap(ubMap, forOp.getUpperBoundOperands());
SmallVector<AffineExpr, 4> lbSplatExpr(ubValueMap.getNumResults(),
lbMap.getResult(0));
auto lbMapSplat = AffineMap::get(lbMap.getNumDims(), lbMap.getNumSymbols(),
lbSplatExpr, context);
AffineValueMap lbSplatValueMap(lbMapSplat, forOp.getLowerBoundOperands());
AffineValueMap tripCountValueMap;
AffineValueMap::difference(ubValueMap, lbSplatValueMap, &tripCountValueMap);
for (unsigned i = 0, e = tripCountValueMap.getNumResults(); i < e; ++i)
tripCountValueMap.setResult(i,
tripCountValueMap.getResult(i).ceilDiv(step));
*tripCountMap = tripCountValueMap.getAffineMap();
tripCountOperands->assign(tripCountValueMap.getOperands().begin(),
tripCountValueMap.getOperands().end());
}
/// Returns the trip count of the loop if it's a constant, std::nullopt
/// otherwise. This method uses affine expression analysis (in turn using
/// getTripCount) and is able to determine constant trip count in non-trivial
/// cases.
std::optional<uint64_t> mlir::affine::getConstantTripCount(AffineForOp forOp) {
SmallVector<Value, 4> operands;
AffineMap map;
getTripCountMapAndOperands(forOp, &map, &operands);
if (!map)
return std::nullopt;
// Take the min if all trip counts are constant.
std::optional<uint64_t> tripCount;
for (auto resultExpr : map.getResults()) {
if (auto constExpr = dyn_cast<AffineConstantExpr>(resultExpr)) {
if (tripCount.has_value())
tripCount =
std::min(*tripCount, static_cast<uint64_t>(constExpr.getValue()));
else
tripCount = constExpr.getValue();
} else
return std::nullopt;
}
return tripCount;
}
/// Returns the greatest known integral divisor of the trip count. Affine
/// expression analysis is used (indirectly through getTripCount), and
/// this method is thus able to determine non-trivial divisors.

View File

@@ -7,7 +7,9 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/IR/AffineExprVisitor.h"
@@ -2448,6 +2450,65 @@ std::optional<OpFoldResult> AffineForOp::getSingleUpperBound() {
return OpFoldResult(b.getI64IntegerAttr(getConstantUpperBound()));
}
void mlir::affine::replaceIterArgsAndYieldResults(AffineForOp forOp) {
// Replace uses of iter arguments with iter operands (initial values).
OperandRange iterOperands = forOp.getInits();
MutableArrayRef<BlockArgument> iterArgs = forOp.getRegionIterArgs();
for (auto [operand, arg] : llvm::zip(iterOperands, iterArgs))
arg.replaceAllUsesWith(operand);
// Replace uses of loop results with the values yielded by the loop.
ResultRange outerResults = forOp.getResults();
OperandRange innerResults = forOp.getBody()->getTerminator()->getOperands();
for (auto [outer, inner] : llvm::zip(outerResults, innerResults))
outer.replaceAllUsesWith(inner);
}
LogicalResult AffineForOp::promoteIfSingleIteration(RewriterBase &rewriter) {
auto forOp = cast<AffineForOp>(getOperation());
std::optional<uint64_t> tripCount = getConstantTripCount(forOp);
if (!tripCount || *tripCount != 1)
return failure();
// TODO: extend this for arbitrary affine bounds.
if (forOp.getLowerBoundMap().getNumResults() != 1)
return failure();
// Replaces all IV uses to its single iteration value.
BlockArgument iv = forOp.getInductionVar();
Block *parentBlock = forOp->getBlock();
if (!iv.use_empty()) {
if (forOp.hasConstantLowerBound()) {
OpBuilder topBuilder(forOp->getParentOfType<func::FuncOp>().getBody());
auto constOp = topBuilder.create<arith::ConstantIndexOp>(
forOp.getLoc(), forOp.getConstantLowerBound());
iv.replaceAllUsesWith(constOp);
} else {
OperandRange lbOperands = forOp.getLowerBoundOperands();
AffineMap lbMap = forOp.getLowerBoundMap();
OpBuilder builder(forOp);
if (lbMap == builder.getDimIdentityMap()) {
// No need of generating an affine.apply.
iv.replaceAllUsesWith(lbOperands[0]);
} else {
auto affineApplyOp =
builder.create<AffineApplyOp>(forOp.getLoc(), lbMap, lbOperands);
iv.replaceAllUsesWith(affineApplyOp);
}
}
}
replaceIterArgsAndYieldResults(forOp);
// Move the loop body operations, except for its terminator, to the loop's
// containing block.
forOp.getBody()->back().erase();
parentBlock->getOperations().splice(Block::iterator(forOp),
forOp.getBody()->getOperations());
forOp.erase();
return success();
}
FailureOr<LoopLikeOpInterface> AffineForOp::replaceWithAdditionalYields(
RewriterBase &rewriter, ValueRange newInitOperands,
bool replaceInitOperandUsesInLoop,
@@ -2546,6 +2607,79 @@ AffineParallelOp mlir::affine::getAffineParallelInductionVarOwner(Value val) {
return nullptr;
}
/// Returns the trip count of the loop as an affine expression if the latter is
/// expressible as an affine expression, and nullptr otherwise. The trip count
/// expression is simplified before returning. This method only utilizes map
/// composition to construct lower and upper bounds before computing the trip
/// count expressions.
void mlir::affine::getTripCountMapAndOperands(
AffineForOp forOp, AffineMap *tripCountMap,
SmallVectorImpl<Value> *tripCountOperands) {
MLIRContext *context = forOp.getContext();
int64_t step = forOp.getStepAsInt();
int64_t loopSpan;
if (forOp.hasConstantBounds()) {
int64_t lb = forOp.getConstantLowerBound();
int64_t ub = forOp.getConstantUpperBound();
loopSpan = ub - lb;
if (loopSpan < 0)
loopSpan = 0;
*tripCountMap = AffineMap::getConstantMap(ceilDiv(loopSpan, step), context);
tripCountOperands->clear();
return;
}
auto lbMap = forOp.getLowerBoundMap();
auto ubMap = forOp.getUpperBoundMap();
if (lbMap.getNumResults() != 1) {
*tripCountMap = AffineMap();
return;
}
// Difference of each upper bound expression from the single lower bound
// expression (divided by the step) provides the expressions for the trip
// count map.
AffineValueMap ubValueMap(ubMap, forOp.getUpperBoundOperands());
SmallVector<AffineExpr, 4> lbSplatExpr(ubValueMap.getNumResults(),
lbMap.getResult(0));
auto lbMapSplat = AffineMap::get(lbMap.getNumDims(), lbMap.getNumSymbols(),
lbSplatExpr, context);
AffineValueMap lbSplatValueMap(lbMapSplat, forOp.getLowerBoundOperands());
AffineValueMap tripCountValueMap;
AffineValueMap::difference(ubValueMap, lbSplatValueMap, &tripCountValueMap);
for (unsigned i = 0, e = tripCountValueMap.getNumResults(); i < e; ++i)
tripCountValueMap.setResult(i,
tripCountValueMap.getResult(i).ceilDiv(step));
*tripCountMap = tripCountValueMap.getAffineMap();
tripCountOperands->assign(tripCountValueMap.getOperands().begin(),
tripCountValueMap.getOperands().end());
}
std::optional<uint64_t> mlir::affine::getConstantTripCount(AffineForOp forOp) {
SmallVector<Value, 4> operands;
AffineMap map;
getTripCountMapAndOperands(forOp, &map, &operands);
if (!map)
return std::nullopt;
// Take the min if all trip counts are constant.
std::optional<uint64_t> tripCount;
for (auto resultExpr : map.getResults()) {
if (auto constExpr = dyn_cast<AffineConstantExpr>(resultExpr)) {
if (tripCount.has_value())
tripCount =
std::min(*tripCount, static_cast<uint64_t>(constExpr.getValue()));
else
tripCount = constExpr.getValue();
} else
return std::nullopt;
}
return tripCount;
}
/// Extracts the induction variables from a list of AffineForOps and returns
/// them.
void mlir::affine::extractForInductionVars(ArrayRef<AffineForOp> forInsts,
@@ -2913,8 +3047,7 @@ static void composeSetAndOperands(IntegerSet &set,
}
/// Canonicalize an affine if op's conditional (integer set + operands).
LogicalResult AffineIfOp::fold(FoldAdaptor,
SmallVectorImpl<OpFoldResult> &) {
LogicalResult AffineIfOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
auto set = getIntegerSet();
SmallVector<Value, 4> operands(getOperands());
composeSetAndOperands(set, operands);
@@ -3005,11 +3138,11 @@ static LogicalResult
verifyMemoryOpIndexing(Operation *op, AffineMapAttr mapAttr,
Operation::operand_range mapOperands,
MemRefType memrefType, unsigned numIndexOperands) {
AffineMap map = mapAttr.getValue();
if (map.getNumResults() != memrefType.getRank())
return op->emitOpError("affine map num results must equal memref rank");
if (map.getNumInputs() != numIndexOperands)
return op->emitOpError("expects as many subscripts as affine map inputs");
AffineMap map = mapAttr.getValue();
if (map.getNumResults() != memrefType.getRank())
return op->emitOpError("affine map num results must equal memref rank");
if (map.getNumInputs() != numIndexOperands)
return op->emitOpError("expects as many subscripts as affine map inputs");
Region *scope = getAffineScope(op);
for (auto idx : mapOperands) {

View File

@@ -219,13 +219,14 @@ void AffineDataCopyGeneration::runOnOperation() {
// Promote any single iteration loops in the copy nests and collect
// load/stores to simplify.
IRRewriter rewriter(f.getContext());
SmallVector<Operation *, 4> copyOps;
for (Operation *nest : copyNests)
// With a post order walk, the erasure of loops does not affect
// continuation of the walk or the collection of load/store ops.
nest->walk([&](Operation *op) {
if (auto forOp = dyn_cast<AffineForOp>(op))
(void)promoteIfSingleIteration(forOp);
(void)forOp.promoteIfSingleIteration(rewriter);
else if (isa<AffineLoadOp, AffineStoreOp>(op))
copyOps.push_back(op);
});

View File

@@ -456,6 +456,7 @@ void mlir::affine::fuseLoops(AffineForOp srcForOp, AffineForOp dstForOp,
return (buildSliceTripCountMap(srcSlice, &sliceTripCountMap) &&
(getSliceIterationCount(sliceTripCountMap) == 1));
};
IRRewriter rewriter(srcForOp.getContext());
// Fix up and if possible, eliminate single iteration loops.
for (AffineForOp forOp : sliceLoops) {
if (isLoopParallelAndContainsReduction(forOp) &&
@@ -463,9 +464,8 @@ void mlir::affine::fuseLoops(AffineForOp srcForOp, AffineForOp dstForOp,
// Patch reduction loop - only ones that are sibling-fused with the
// destination loop - into the parent loop.
(void)promoteSingleIterReductionLoop(forOp, true);
else
// Promote any single iteration slice loops.
(void)promoteIfSingleIteration(forOp);
else // Promote any single iteration slice loops.
(void)forOp.promoteIfSingleIteration(rewriter);
}
}

View File

@@ -110,68 +110,6 @@ getCleanupLoopLowerBound(AffineForOp forOp, unsigned unrollFactor,
lb.erase();
}
/// Helper to replace uses of loop carried values (iter_args) and loop
/// yield values while promoting single iteration affine.for ops.
static void replaceIterArgsAndYieldResults(AffineForOp forOp) {
// Replace uses of iter arguments with iter operands (initial values).
auto iterOperands = forOp.getInits();
auto iterArgs = forOp.getRegionIterArgs();
for (auto e : llvm::zip(iterOperands, iterArgs))
std::get<1>(e).replaceAllUsesWith(std::get<0>(e));
// Replace uses of loop results with the values yielded by the loop.
auto outerResults = forOp.getResults();
auto innerResults = forOp.getBody()->getTerminator()->getOperands();
for (auto e : llvm::zip(outerResults, innerResults))
std::get<0>(e).replaceAllUsesWith(std::get<1>(e));
}
/// Promotes the loop body of a forOp to its containing block if the forOp
/// was known to have a single iteration.
LogicalResult mlir::affine::promoteIfSingleIteration(AffineForOp forOp) {
std::optional<uint64_t> tripCount = getConstantTripCount(forOp);
if (!tripCount || *tripCount != 1)
return failure();
// TODO: extend this for arbitrary affine bounds.
if (forOp.getLowerBoundMap().getNumResults() != 1)
return failure();
// Replaces all IV uses to its single iteration value.
auto iv = forOp.getInductionVar();
auto *parentBlock = forOp->getBlock();
if (!iv.use_empty()) {
if (forOp.hasConstantLowerBound()) {
OpBuilder topBuilder(forOp->getParentOfType<func::FuncOp>().getBody());
auto constOp = topBuilder.create<arith::ConstantIndexOp>(
forOp.getLoc(), forOp.getConstantLowerBound());
iv.replaceAllUsesWith(constOp);
} else {
auto lbOperands = forOp.getLowerBoundOperands();
auto lbMap = forOp.getLowerBoundMap();
OpBuilder builder(forOp);
if (lbMap == builder.getDimIdentityMap()) {
// No need of generating an affine.apply.
iv.replaceAllUsesWith(lbOperands[0]);
} else {
auto affineApplyOp =
builder.create<AffineApplyOp>(forOp.getLoc(), lbMap, lbOperands);
iv.replaceAllUsesWith(affineApplyOp);
}
}
}
replaceIterArgsAndYieldResults(forOp);
// Move the loop body operations, except for its terminator, to the loop's
// containing block.
forOp.getBody()->back().erase();
parentBlock->getOperations().splice(Block::iterator(forOp),
forOp.getBody()->getOperations());
forOp.erase();
return success();
}
/// Generates an affine.for op with the specified lower and upper bounds
/// while generating the right IV remappings to realize shifts for operations in
/// its body. The operations that go into the loop body are specified in
@@ -218,7 +156,9 @@ static AffineForOp generateShiftedLoop(
for (auto *op : ops)
bodyBuilder.clone(*op, operandMap);
};
if (succeeded(promoteIfSingleIteration(loopChunk)))
IRRewriter rewriter(loopChunk.getContext());
if (succeeded(loopChunk.promoteIfSingleIteration(rewriter)))
return AffineForOp();
return loopChunk;
}
@@ -892,12 +832,13 @@ void mlir::affine::getTileableBands(
/// Unrolls this loop completely.
LogicalResult mlir::affine::loopUnrollFull(AffineForOp forOp) {
std::optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp);
IRRewriter rewriter(forOp.getContext());
if (mayBeConstantTripCount.has_value()) {
uint64_t tripCount = *mayBeConstantTripCount;
if (tripCount == 0)
return success();
if (tripCount == 1)
return promoteIfSingleIteration(forOp);
return forOp.promoteIfSingleIteration(rewriter);
return loopUnrollByFactor(forOp, tripCount);
}
return failure();
@@ -1003,7 +944,8 @@ static LogicalResult generateCleanupLoopForUnroll(AffineForOp forOp,
cleanupForOp.setLowerBound(cleanupOperands, cleanupMap);
// Promote the loop body up if this has turned into a single iteration loop.
(void)promoteIfSingleIteration(cleanupForOp);
IRRewriter rewriter(cleanupForOp.getContext());
(void)cleanupForOp.promoteIfSingleIteration(rewriter);
// Adjust upper bound of the original loop; this is the same as the lower
// bound of the cleanup loop.
@@ -1019,10 +961,11 @@ LogicalResult mlir::affine::loopUnrollByFactor(
bool cleanUpUnroll) {
assert(unrollFactor > 0 && "unroll factor should be positive");
IRRewriter rewriter(forOp.getContext());
std::optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp);
if (unrollFactor == 1) {
if (mayBeConstantTripCount && *mayBeConstantTripCount == 1 &&
failed(promoteIfSingleIteration(forOp)))
failed(forOp.promoteIfSingleIteration(rewriter)))
return failure();
return success();
}
@@ -1076,7 +1019,7 @@ LogicalResult mlir::affine::loopUnrollByFactor(
/*iterArgs=*/iterArgs, /*yieldedValues=*/yieldedValues);
// Promote the loop body up if this has turned into a single iteration loop.
(void)promoteIfSingleIteration(forOp);
(void)forOp.promoteIfSingleIteration(rewriter);
return success();
}
@@ -1135,10 +1078,11 @@ LogicalResult mlir::affine::loopUnrollJamByFactor(AffineForOp forOp,
uint64_t unrollJamFactor) {
assert(unrollJamFactor > 0 && "unroll jam factor should be positive");
IRRewriter rewriter(forOp.getContext());
std::optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp);
if (unrollJamFactor == 1) {
if (mayBeConstantTripCount && *mayBeConstantTripCount == 1 &&
failed(promoteIfSingleIteration(forOp)))
failed(forOp.promoteIfSingleIteration(rewriter)))
return failure();
return success();
}
@@ -1198,7 +1142,6 @@ LogicalResult mlir::affine::loopUnrollJamByFactor(AffineForOp forOp,
// `unrollJamFactor` copies of its iterOperands, iter_args and yield
// operands.
SmallVector<AffineForOp, 4> newLoopsWithIterArgs;
IRRewriter rewriter(forOp.getContext());
for (AffineForOp oldForOp : loopsWithIterArgs) {
SmallVector<Value> dupIterOperands, dupYieldOperands;
ValueRange oldIterOperands = oldForOp.getInits();
@@ -1321,7 +1264,7 @@ LogicalResult mlir::affine::loopUnrollJamByFactor(AffineForOp forOp,
}
// Promote the loop body up if this has turned into a single iteration loop.
(void)promoteIfSingleIteration(forOp);
(void)forOp.promoteIfSingleIteration(rewriter);
return success();
}

View File

@@ -552,7 +552,8 @@ void mlir::affine::normalizeAffineParallel(AffineParallelOp op) {
LogicalResult mlir::affine::normalizeAffineFor(AffineForOp op,
bool promoteSingleIter) {
if (promoteSingleIter && succeeded(promoteIfSingleIteration(op)))
IRRewriter rewriter(op.getContext());
if (promoteSingleIter && succeeded(op.promoteIfSingleIteration(rewriter)))
return success();
// Check if the forop is already normalized.

View File

@@ -107,13 +107,14 @@ void TestAffineDataCopy::runOnOperation() {
// Promote any single iteration loops in the copy nests and simplify
// load/stores.
IRRewriter rewriter(&getContext());
SmallVector<Operation *, 4> copyOps;
for (Operation *nest : copyNests) {
// With a post order walk, the erasure of loops does not affect
// continuation of the walk or the collection of load/store ops.
nest->walk([&](Operation *op) {
if (auto forOp = dyn_cast<AffineForOp>(op))
(void)promoteIfSingleIteration(forOp);
(void)forOp.promoteIfSingleIteration(rewriter);
else if (auto loadOp = dyn_cast<AffineLoadOp>(op))
copyOps.push_back(loadOp);
else if (auto storeOp = dyn_cast<AffineStoreOp>(op))