mirror of
https://github.com/intel/llvm.git
synced 2026-01-13 19:08:21 +08:00
[mlir][IntegerRangeAnalysis] Handle multi-dimensional loops (#170765)
Since LoopLikeInterface has (for some time) been extended to handle multiple induction variables (and thus lower and upper bounds), handle those bounds one at a time.
This commit is contained in:
committed by
GitHub
parent
5e4974fbd3
commit
ad1edc9cbc
@@ -180,23 +180,20 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments(
|
||||
return;
|
||||
}
|
||||
|
||||
/// Given the results of getConstant{Lower,Upper}Bound() or getConstantStep()
|
||||
/// on a LoopLikeInterface return the lower/upper bound for that result if
|
||||
/// possible.
|
||||
auto getLoopBoundFromFold = [&](std::optional<OpFoldResult> loopBound,
|
||||
Type boundType, Block *block, bool getUpper) {
|
||||
/// Given a lower bound, upper bound, or step from a LoopLikeInterface return
|
||||
/// the lower/upper bound for that result if possible.
|
||||
auto getLoopBoundFromFold = [&](OpFoldResult loopBound, Type boundType,
|
||||
Block *block, bool getUpper) {
|
||||
unsigned int width = ConstantIntRanges::getStorageBitwidth(boundType);
|
||||
if (loopBound.has_value()) {
|
||||
if (auto attr = dyn_cast<Attribute>(*loopBound)) {
|
||||
if (auto bound = dyn_cast_or_null<IntegerAttr>(attr))
|
||||
return bound.getValue();
|
||||
} else if (auto value = llvm::dyn_cast_if_present<Value>(*loopBound)) {
|
||||
const IntegerValueRangeLattice *lattice =
|
||||
getLatticeElementFor(getProgramPointBefore(block), value);
|
||||
if (lattice != nullptr && !lattice->getValue().isUninitialized())
|
||||
return getUpper ? lattice->getValue().getValue().smax()
|
||||
: lattice->getValue().getValue().smin();
|
||||
}
|
||||
if (auto attr = dyn_cast<Attribute>(loopBound)) {
|
||||
if (auto bound = dyn_cast<IntegerAttr>(attr))
|
||||
return bound.getValue();
|
||||
} else if (auto value = llvm::dyn_cast<Value>(loopBound)) {
|
||||
const IntegerValueRangeLattice *lattice =
|
||||
getLatticeElementFor(getProgramPointBefore(block), value);
|
||||
if (lattice != nullptr && !lattice->getValue().isUninitialized())
|
||||
return getUpper ? lattice->getValue().getValue().smax()
|
||||
: lattice->getValue().getValue().smin();
|
||||
}
|
||||
// Given the results of getConstant{Lower,Upper}Bound()
|
||||
// or getConstantStep() on a LoopLikeInterface return the lower/upper
|
||||
@@ -207,38 +204,43 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments(
|
||||
|
||||
// Infer bounds for loop arguments that have static bounds
|
||||
if (auto loop = dyn_cast<LoopLikeOpInterface>(op)) {
|
||||
std::optional<Value> iv = loop.getSingleInductionVar();
|
||||
if (!iv) {
|
||||
std::optional<llvm::SmallVector<Value>> maybeIvs =
|
||||
loop.getLoopInductionVars();
|
||||
if (!maybeIvs) {
|
||||
return SparseForwardDataFlowAnalysis ::visitNonControlFlowArguments(
|
||||
op, successor, argLattices, firstIndex);
|
||||
}
|
||||
Block *block = iv->getParentBlock();
|
||||
std::optional<OpFoldResult> lowerBound = loop.getSingleLowerBound();
|
||||
std::optional<OpFoldResult> upperBound = loop.getSingleUpperBound();
|
||||
std::optional<OpFoldResult> step = loop.getSingleStep();
|
||||
APInt min = getLoopBoundFromFold(lowerBound, iv->getType(), block,
|
||||
/*getUpper=*/false);
|
||||
APInt max = getLoopBoundFromFold(upperBound, iv->getType(), block,
|
||||
/*getUpper=*/true);
|
||||
// Assume positivity for uniscoverable steps by way of getUpper = true.
|
||||
APInt stepVal =
|
||||
getLoopBoundFromFold(step, iv->getType(), block, /*getUpper=*/true);
|
||||
// This shouldn't be returning nullopt if there are indunction variables.
|
||||
SmallVector<OpFoldResult> lowerBounds = *loop.getLoopLowerBounds();
|
||||
SmallVector<OpFoldResult> upperBounds = *loop.getLoopUpperBounds();
|
||||
SmallVector<OpFoldResult> steps = *loop.getLoopSteps();
|
||||
for (auto [iv, lowerBound, upperBound, step] :
|
||||
llvm::zip_equal(*maybeIvs, lowerBounds, upperBounds, steps)) {
|
||||
Block *block = iv.getParentBlock();
|
||||
APInt min = getLoopBoundFromFold(lowerBound, iv.getType(), block,
|
||||
/*getUpper=*/false);
|
||||
APInt max = getLoopBoundFromFold(upperBound, iv.getType(), block,
|
||||
/*getUpper=*/true);
|
||||
// Assume positivity for uniscoverable steps by way of getUpper = true.
|
||||
APInt stepVal =
|
||||
getLoopBoundFromFold(step, iv.getType(), block, /*getUpper=*/true);
|
||||
|
||||
if (stepVal.isNegative()) {
|
||||
std::swap(min, max);
|
||||
} else {
|
||||
// Correct the upper bound by subtracting 1 so that it becomes a <=
|
||||
// bound, because loops do not generally include their upper bound.
|
||||
max -= 1;
|
||||
}
|
||||
if (stepVal.isNegative()) {
|
||||
std::swap(min, max);
|
||||
} else {
|
||||
// Correct the upper bound by subtracting 1 so that it becomes a <=
|
||||
// bound, because loops do not generally include their upper bound.
|
||||
max -= 1;
|
||||
}
|
||||
|
||||
// If we infer the lower bound to be larger than the upper bound, the
|
||||
// resulting range is meaningless and should not be used in further
|
||||
// inferences.
|
||||
if (max.sge(min)) {
|
||||
IntegerValueRangeLattice *ivEntry = getLatticeElement(*iv);
|
||||
auto ivRange = ConstantIntRanges::fromSigned(min, max);
|
||||
propagateIfChanged(ivEntry, ivEntry->join(IntegerValueRange{ivRange}));
|
||||
// If we infer the lower bound to be larger than the upper bound, the
|
||||
// resulting range is meaningless and should not be used in further
|
||||
// inferences.
|
||||
if (max.sge(min)) {
|
||||
IntegerValueRangeLattice *ivEntry = getLatticeElement(iv);
|
||||
auto ivRange = ConstantIntRanges::fromSigned(min, max);
|
||||
propagateIfChanged(ivEntry, ivEntry->join(IntegerValueRange{ivRange}));
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -184,3 +184,19 @@ func.func @propagate_from_block_to_iterarg(%arg0: index, %arg1: i1) {
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @multiple_loop_ivs
|
||||
func.func @multiple_loop_ivs(%arg0: memref<?x64xi32>) {
|
||||
%ub1 = test.with_bounds { umin = 1 : index, umax = 32 : index,
|
||||
smin = 1 : index, smax = 32 : index } : index
|
||||
%c0_i32 = arith.constant 0 : i32
|
||||
// CHECK: scf.forall
|
||||
scf.forall (%arg1, %arg2) in (%ub1, 64) {
|
||||
// CHECK: test.reflect_bounds {smax = 31 : index, smin = 0 : index, umax = 31 : index, umin = 0 : index}
|
||||
%1 = test.reflect_bounds %arg1 : index
|
||||
// CHECK-NEXT: test.reflect_bounds {smax = 63 : index, smin = 0 : index, umax = 63 : index, umin = 0 : index}
|
||||
%2 = test.reflect_bounds %arg2 : index
|
||||
memref.store %c0_i32, %arg0[%1, %2] : memref<?x64xi32>
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user