Add FlatAffineConstraints::containsId to avoid using findId when position isn't

needed + other cleanup
- clean up unionBoundingBox (hoist SmallVector allocations out of loop).

PiperOrigin-RevId: 237141668
This commit is contained in:
Uday Bondhugula
2019-03-06 16:18:27 -08:00
committed by jpienaar
parent 9e425a06f7
commit b8b15c7700
3 changed files with 26 additions and 19 deletions

View File

@@ -436,11 +436,15 @@ public:
/// constant. Asserts if the 'id' is not found.
void setIdToConstant(const Value &id, int64_t val);
/// Looks up the identifier with the specified Value. Returns false if not
/// found, true if found. pos is set to the (column) position of the
/// identifier.
/// Looks up the position of the identifier with the specified Value. Returns
/// true if found (false otherwise). `pos' is set to the (column) position of
/// the identifier.
bool findId(const Value &id, unsigned *pos) const;
/// Returns true if an identifier with the specified Value exists, false
/// otherwise.
bool containsId(const Value &id) const;
// Add identifiers of the specified kind - specified positions are relative to
// the kind of identifier. The coefficient column corresponding to the added
// identifier is initialized to zero. 'id' is the Value corresponding to the

View File

@@ -1888,6 +1888,12 @@ bool FlatAffineConstraints::findId(const Value &id, unsigned *pos) const {
return false;
}
bool FlatAffineConstraints::containsId(const Value &id) const {
return llvm::any_of(ids, [&](const Optional<Value *> &mayBeId) {
return mayBeId.hasValue() && mayBeId.getValue() == &id;
});
}
void FlatAffineConstraints::setDimSymbolSeparation(unsigned newSymbolCount) {
assert(newSymbolCount <= numDims + numSymbols &&
"invalid separation position");
@@ -2696,19 +2702,21 @@ bool FlatAffineConstraints::unionBoundingBox(
boundingLbs.reserve(2 * getNumDimIds());
boundingUbs.reserve(2 * getNumDimIds());
SmallVector<int64_t, 4> lb, otherLb;
lb.reserve(getNumSymbolIds() + 1);
otherLb.reserve(getNumSymbolIds() + 1);
// To hold lower and upper bounds for each dimension.
SmallVector<int64_t, 4> lb, otherLb, ub, otherUb;
// To compute min of lower bounds and max of upper bounds for each dimension.
SmallVector<int64_t, 4> minLb, maxUb;
// To compute final new lower and upper bounds for the union.
SmallVector<int64_t, 8> newLb(getNumCols()), newUb(getNumCols());
int64_t lbDivisor, otherLbDivisor;
for (unsigned d = 0, e = getNumDimIds(); d < e; ++d) {
lb.clear();
auto extent = getConstantBoundOnDimSize(d, &lb, &lbDivisor);
if (!extent.hasValue())
// TODO(bondhugula): symbolic extents when necessary.
// TODO(bondhugula): handle union if a dimension is unbounded.
return false;
otherLb.clear();
auto otherExtent =
other.getConstantBoundOnDimSize(d, &otherLb, &otherLbDivisor);
if (!otherExtent.hasValue() || lbDivisor != otherLbDivisor)
@@ -2717,9 +2725,6 @@ bool FlatAffineConstraints::unionBoundingBox(
assert(lbDivisor > 0 && "divisor always expected to be positive");
// Compute min of lower bounds and max of upper bounds.
SmallVector<int64_t, 4> minLb, maxUb;
auto res = compareBounds(lb, otherLb);
// Identify min.
if (res == BoundCmpResult::Less || res == BoundCmpResult::Equal) {
@@ -2737,7 +2742,8 @@ bool FlatAffineConstraints::unionBoundingBox(
}
// Do the same for ub's but max of upper bounds.
SmallVector<int64_t, 4> ub(lb), otherUb(otherLb);
ub = lb;
otherUb = otherLb;
ub.back() += extent.getValue() - 1;
otherUb.back() += otherExtent.getValue() - 1;
@@ -2757,8 +2763,8 @@ bool FlatAffineConstraints::unionBoundingBox(
maxUb.back() = std::max(constUb.getValue(), constOtherUb.getValue());
}
SmallVector<int64_t, 8> newLb(getNumCols(), 0);
SmallVector<int64_t, 8> newUb(getNumCols(), 0);
std::fill(newLb.begin(), newLb.end(), 0);
std::fill(newUb.begin(), newUb.end(), 0);
// The divisor for lb, ub, otherLb, otherUb at this point is lbDivisor,
// and so it's the divisor for newLb and newUb as well.

View File

@@ -70,9 +70,7 @@ bool ComputationSliceState::getAsConstraints(FlatAffineConstraints *cst) {
// Add loop bound constraints for values which are loop IVs and equality
// constraints for symbols which are constants.
for (const auto &value : values) {
unsigned loc;
(void)loc;
assert(cst->findId(*value, &loc));
assert(cst->containsId(*value) && "value expected to be present");
if (isValidSymbol(value)) {
// Check if the symbol is a constant.
if (auto *inst = value->getDefiningInst()) {
@@ -256,8 +254,7 @@ bool MemRefRegion::compute(Instruction *inst, unsigned loopDepth,
if (sliceState != nullptr) {
// Add dim and symbol slice operands.
for (const auto &operand : sliceState->lbOperands[0]) {
unsigned loc;
if (!cst.findId(*operand, &loc)) {
if (!cst.containsId(*operand)) {
if (isValidSymbol(operand)) {
cst.addSymbolId(cst.getNumSymbolIds(), const_cast<Value *>(operand));
// Check if the symbol is a constant.