//===- IntegerRangeAnalysis.cpp - Integer range analysis --------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file defines the dataflow analysis class for integer range inference // which is used in transformations over the `arith` dialect such as // branch elimination or signed->unsigned rewriting // //===----------------------------------------------------------------------===// #include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h" #include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h" #include "mlir/Analysis/DataFlow/SparseAnalysis.h" #include "mlir/Analysis/DataFlowFramework.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/Operation.h" #include "mlir/IR/OperationSupport.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/IR/Value.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Interfaces/InferIntRangeInterface.h" #include "mlir/Interfaces/LoopLikeInterface.h" #include "mlir/Support/DebugStringHelper.h" #include "mlir/Support/LLVM.h" #include "llvm/ADT/STLExtras.h" #include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" #include "llvm/Support/DebugLog.h" #include #include #include #define DEBUG_TYPE "int-range-analysis" using namespace mlir; using namespace mlir::dataflow; namespace mlir::dataflow { LogicalResult staticallyNonNegative(DataFlowSolver &solver, Value v) { auto *result = solver.lookupState(v); if (!result || result->getValue().isUninitialized()) return failure(); const ConstantIntRanges &range = result->getValue().getValue(); return success(range.smin().isNonNegative()); } LogicalResult staticallyNonNegative(DataFlowSolver &solver, Operation *op) { auto nonNegativePred = [&solver](Value v) -> bool { return succeeded(staticallyNonNegative(solver, v)); }; return success(llvm::all_of(op->getOperands(), nonNegativePred) && llvm::all_of(op->getResults(), nonNegativePred)); } } // namespace mlir::dataflow void IntegerValueRangeLattice::onUpdate(DataFlowSolver *solver) const { Lattice::onUpdate(solver); // If the integer range can be narrowed to a constant, update the constant // value of the SSA value. std::optional constant = getValue().getValue().getConstantValue(); auto value = cast(anchor); auto *cv = solver->getOrCreateState>(value); if (!constant) return solver->propagateIfChanged( cv, cv->join(ConstantValue::getUnknownConstant())); Dialect *dialect; if (auto *parent = value.getDefiningOp()) dialect = parent->getDialect(); else dialect = value.getParentBlock()->getParentOp()->getDialect(); Attribute cstAttr; if (isa(value.getType())) { cstAttr = IntegerAttr::get(value.getType(), *constant); } else if (auto shapedTy = dyn_cast(value.getType())) { cstAttr = SplatElementsAttr::get(shapedTy, *constant); } else { llvm::report_fatal_error( Twine("FIXME: Don't know how to create a constant for this type: ") + mlir::debugString(value.getType())); } solver->propagateIfChanged(cv, cv->join(ConstantValue(cstAttr, dialect))); } LogicalResult IntegerRangeAnalysis::visitOperation( Operation *op, ArrayRef operands, ArrayRef results) { auto inferrable = dyn_cast(op); if (!inferrable) { setAllToEntryStates(results); return success(); } LDBG() << "Inferring ranges for " << OpWithFlags(op, OpPrintingFlags().skipRegions()); auto argRanges = llvm::map_to_vector( operands, [](const IntegerValueRangeLattice *lattice) { return lattice->getValue(); }); auto joinCallback = [&](Value v, const IntegerValueRange &attrs) { auto result = dyn_cast(v); if (!result) return; assert(llvm::is_contained(op->getResults(), result)); LDBG() << "Inferred range " << attrs; IntegerValueRangeLattice *lattice = results[result.getResultNumber()]; IntegerValueRange oldRange = lattice->getValue(); ChangeResult changed = lattice->join(attrs); // Catch loop results with loop variant bounds and conservatively make // them [-inf, inf] so we don't circle around infinitely often (because // the dataflow analysis in MLIR doesn't attempt to work out trip counts // and often can't). bool isYieldedResult = llvm::any_of(v.getUsers(), [](Operation *op) { return op->hasTrait(); }); if (isYieldedResult && !oldRange.isUninitialized() && !(lattice->getValue() == oldRange)) { LDBG() << "Loop variant loop result detected"; changed |= lattice->join(IntegerValueRange::getMaxRange(v)); } propagateIfChanged(lattice, changed); }; inferrable.inferResultRangesFromOptional(argRanges, joinCallback); return success(); } void IntegerRangeAnalysis::visitNonControlFlowArguments( Operation *op, const RegionSuccessor &successor, ArrayRef argLattices, unsigned firstIndex) { if (auto inferrable = dyn_cast(op)) { LDBG() << "Inferring ranges for " << OpWithFlags(op, OpPrintingFlags().skipRegions()); auto argRanges = llvm::map_to_vector(op->getOperands(), [&](Value value) { return getLatticeElementFor(getProgramPointAfter(op), value)->getValue(); }); auto joinCallback = [&](Value v, const IntegerValueRange &attrs) { auto arg = dyn_cast(v); if (!arg) return; if (!llvm::is_contained(successor.getSuccessor()->getArguments(), arg)) return; LDBG() << "Inferred range " << attrs; IntegerValueRangeLattice *lattice = argLattices[arg.getArgNumber()]; IntegerValueRange oldRange = lattice->getValue(); ChangeResult changed = lattice->join(attrs); // Catch loop results with loop variant bounds and conservatively make // them [-inf, inf] so we don't circle around infinitely often (because // the dataflow analysis in MLIR doesn't attempt to work out trip counts // and often can't). bool isYieldedValue = llvm::any_of(v.getUsers(), [](Operation *op) { return op->hasTrait(); }); if (isYieldedValue && !oldRange.isUninitialized() && !(lattice->getValue() == oldRange)) { LDBG() << "Loop variant loop result detected"; changed |= lattice->join(IntegerValueRange::getMaxRange(v)); } propagateIfChanged(lattice, changed); }; inferrable.inferResultRangesFromOptional(argRanges, joinCallback); return; } /// Given a lower bound, upper bound, or step from a LoopLikeInterface return /// the lower/upper bound for that result if possible. auto getLoopBoundFromFold = [&](OpFoldResult loopBound, Type boundType, Block *block, bool getUpper) { unsigned int width = ConstantIntRanges::getStorageBitwidth(boundType); if (auto attr = dyn_cast(loopBound)) { if (auto bound = dyn_cast(attr)) return bound.getValue(); } else if (auto value = llvm::dyn_cast(loopBound)) { const IntegerValueRangeLattice *lattice = getLatticeElementFor(getProgramPointBefore(block), value); if (lattice != nullptr && !lattice->getValue().isUninitialized()) return getUpper ? lattice->getValue().getValue().smax() : lattice->getValue().getValue().smin(); } // Given the results of getConstant{Lower,Upper}Bound() // or getConstantStep() on a LoopLikeInterface return the lower/upper // bound return getUpper ? APInt::getSignedMaxValue(width) : APInt::getSignedMinValue(width); }; // Infer bounds for loop arguments that have static bounds if (auto loop = dyn_cast(op)) { std::optional> maybeIvs = loop.getLoopInductionVars(); if (!maybeIvs) { return SparseForwardDataFlowAnalysis ::visitNonControlFlowArguments( op, successor, argLattices, firstIndex); } // This shouldn't be returning nullopt if there are indunction variables. SmallVector lowerBounds = *loop.getLoopLowerBounds(); SmallVector upperBounds = *loop.getLoopUpperBounds(); SmallVector steps = *loop.getLoopSteps(); for (auto [iv, lowerBound, upperBound, step] : llvm::zip_equal(*maybeIvs, lowerBounds, upperBounds, steps)) { Block *block = iv.getParentBlock(); APInt min = getLoopBoundFromFold(lowerBound, iv.getType(), block, /*getUpper=*/false); APInt max = getLoopBoundFromFold(upperBound, iv.getType(), block, /*getUpper=*/true); // Assume positivity for uniscoverable steps by way of getUpper = true. APInt stepVal = getLoopBoundFromFold(step, iv.getType(), block, /*getUpper=*/true); if (stepVal.isNegative()) { std::swap(min, max); } else { // Correct the upper bound by subtracting 1 so that it becomes a <= // bound, because loops do not generally include their upper bound. max -= 1; } // If we infer the lower bound to be larger than the upper bound, the // resulting range is meaningless and should not be used in further // inferences. if (max.sge(min)) { IntegerValueRangeLattice *ivEntry = getLatticeElement(iv); auto ivRange = ConstantIntRanges::fromSigned(min, max); propagateIfChanged(ivEntry, ivEntry->join(IntegerValueRange{ivRange})); } } return; } return SparseForwardDataFlowAnalysis::visitNonControlFlowArguments( op, successor, argLattices, firstIndex); }