Extend loop-fusion's slicing utility + other fixes / updates

- refactor toAffineMapFromEq and the code surrounding it into
  FlatAffineConstraints::getSliceBounds (new call shape sketched below)
- add FlatAffineConstraints methods to detect identifiers as mod's and div's of other
  identifiers
- add FlatAffineConstraints::getConstantLower/UpperBound
- Address b/122118218 (don't assert on invalid fusion depth cmdline flags -
  instead, don't do anything); rename cmdline flags
  src-loop-depth -> fusion-src-loop-depth, dst-loop-depth -> fusion-dst-loop-depth
- AffineExpr/Map print method update: don't fail on null instances (since we have
  a wrapper around a pointer, it's avoidable); rationale: dump/print methods should
  never fail if possible.
- Update memref-dataflow-opt to add an optimization to avoid an unnecessary call to
  isRangeOneToOne when it's trivially going to be true.
- Add additional test cases to exercise the new support
- update a few existing test cases since the maps are now generated uniformly with
  all destination loop operands appearing for the backward slice
- Fix projectOut - fix wrong range passed to getBestIdToEliminate.
- Fix for getConstantBoundOnDimSize() - the bug didn't show up in any test cases
  since we didn't have any non-hyperrectangular ones.
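
A rough sketch of the new slicing call shape (it mirrors the updated call site in
insertBackwardComputationSlice in this change; 'dependenceConstraints',
'srcLoopDepth', and 'context' stand in for the caller's values):

    SmallVector<AffineMap, 4> sliceLbs(srcLoopDepth, AffineMap::Null());
    SmallVector<AffineMap, 4> sliceUbs(srcLoopDepth, AffineMap::Null());
    // Bounds of the first 'srcLoopDepth' identifiers as affine maps of the
    // remaining (destination dim and symbol) identifiers.
    dependenceConstraints.getSliceBounds(srcLoopDepth, context, &sliceLbs,
                                         &sliceUbs);
    // Operands for those maps: the Values of the remaining dim/symbol ids.
    SmallVector<Value *, 4> sliceBoundOperands;
    dependenceConstraints.getIdValues(
        srcLoopDepth, dependenceConstraints.getNumDimAndSymbolIds(),
        &sliceBoundOperands);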

PiperOrigin-RevId: 228265152
Authored by Uday Bondhugula on 2019-01-07 17:34:26 -08:00, committed by jpienaar.
Parent: b934d75b8f
Commit: 21baf86a2f
10 changed files with 714 additions and 190 deletions

View File

@@ -387,18 +387,15 @@ public:
AffineExpr toAffineExpr(unsigned idx, MLIRContext *context);
// Returns an AffineMap that expresses the identifier at pos as a function of
// other dimensional and symbolic identifiers.
// If 'nonZeroDimIds' and 'nonZeroSymbolIds' are non-null, they are populated
// with the positions of the non-zero equality constraint coefficients which
// were used to build the returned AffineMap.
// Returns AffineMap::Null if such an expression can't be constructed.
// TODO(andydavis) Remove 'nonZeroDimIds' and 'nonZeroSymbolIds' from this
// API when we can manage the mapping of Values and ids in the constraint
// system.
AffineMap toAffineMapFromEq(unsigned pos, MLIRContext *context,
SmallVectorImpl<unsigned> *nonZeroDimIds,
SmallVectorImpl<unsigned> *nonZeroSymbolIds);
/// Computes the lower and upper bounds of the first 'num' dimensional
/// identifiers as an affine map of the remaining identifiers (dimensional and
/// symbolic). This method is able to detect identifiers as floordiv's
/// and mod's of affine expressions of other identifiers with respect to
/// (positive) constants. Sets bound map to AffineMap::Null if such a bound
/// can't be found (or yet unimplemented).
void getSliceBounds(unsigned num, MLIRContext *context,
SmallVectorImpl<AffineMap> *lbMaps,
SmallVectorImpl<AffineMap> *ubMaps);
// Adds an inequality (>= 0) from the coefficients specified in inEq.
void addInequality(ArrayRef<int64_t> inEq);
@@ -513,6 +510,7 @@ public:
inline unsigned getNumIds() const { return numIds; }
inline unsigned getNumDimIds() const { return numDims; }
inline unsigned getNumSymbolIds() const { return numSymbols; }
inline unsigned getNumDimAndSymbolIds() const { return numDims + numSymbols; }
inline unsigned getNumLocalIds() const {
return numIds - numDims - numSymbols;
}
@@ -521,24 +519,43 @@ public:
return {ids.data(), ids.size()};
}
/// Returns the Value's associated with the identifiers. Asserts if
/// no Value was associated with an identifier.
inline void getIdValues(SmallVectorImpl<Value *> *values) const {
values->clear();
values->reserve(numIds);
for (unsigned i = 0; i < numIds; i++) {
assert(ids[i].hasValue() && "identifier's Value not set");
values->push_back(ids[i].getValue());
}
}
/// Returns the Value associated with the pos^th identifier. Asserts if
/// no Value identifier was associated.
inline Value *getIdValue(unsigned pos) const {
assert(ids[pos].hasValue() && "identifier's ML Value not set");
assert(ids[pos].hasValue() && "identifier's Value not set");
return ids[pos].getValue();
}
/// Returns the Values associated with identifiers in range [start, end).
/// Asserts if no Value was associated with one of these identifiers.
void getIdValues(unsigned start, unsigned end,
SmallVectorImpl<Value *> *values) const {
assert((start < numIds || start == end) && "invalid start position");
assert(end <= numIds && "invalid end position");
values->clear();
values->reserve(end - start);
for (unsigned i = start; i < end; i++) {
values->push_back(getIdValue(i));
}
}
inline void getAllIdValues(SmallVectorImpl<Value *> *values) const {
getIdValues(0, numIds, values);
}
/// Sets Value associated with the pos^th identifier.
inline void setIdValue(unsigned pos, Value *val) {
assert(pos < numIds && "invalid id position");
ids[pos] = val;
}
/// Sets Values associated with identifiers in the range [start, end).
void setIdValues(unsigned start, unsigned end, ArrayRef<Value *> values) {
assert((start < numIds || end == start) && "invalid start position");
assert(end <= numIds && "invalid end position");
assert(values.size() == end - start);
for (unsigned i = start; i < end; ++i)
ids[i] = values[i - start];
}
/// Clears this list of constraints and copies other into it.
void clearAndCopyFrom(const FlatAffineConstraints &other);
@@ -555,6 +572,14 @@ public:
getConstantBoundOnDimSize(unsigned pos,
SmallVectorImpl<int64_t> *lb = nullptr) const;
/// Returns the constant lower bound for the pos^th identifier if there is
/// one; None otherwise.
Optional<int64_t> getConstantLowerBound(unsigned pos) const;
/// Returns the constant upper bound for the pos^th identifier if there is
/// one; None otherwise.
Optional<int64_t> getConstantUpperBound(unsigned pos) const;
/// Returns true if the set can be trivially detected as being
/// hyper-rectangular on the specified contiguous set of identifiers.
bool isHyperRectangular(unsigned pos, unsigned num) const;
@@ -579,6 +604,11 @@ private:
/// 'false' otherwise.
bool hasInvalidConstraint() const;
/// Returns the constant lower bound if isLower is true, and the upper
/// bound if isLower is false.
template <bool isLower>
Optional<int64_t> getConstantLowerOrUpperBound(unsigned pos) const;
// Eliminates a single identifier at 'position' from equality and inequality
// constraints. Returns 'true' if the identifier was eliminated, and false
// otherwise.
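
A minimal illustration of the new constant bound queries (not code from this
commit; it assumes an empty FlatAffineConstraints 'cst' and the reset() overload
used later in this change) on the one-identifier system 2 <= d0 <= 15:

    // Columns are [d0, const]; inequalities use the canonical c_0*d0 + c_1 >= 0 form.
    cst.reset(2, 0, /*numCols=*/2, /*numDims=*/1, /*numSymbols=*/0, /*numLocals=*/0);
    cst.addInequality({1, -2});   //  d0 - 2  >= 0  (d0 >= 2)
    cst.addInequality({-1, 15});  // -d0 + 15 >= 0  (d0 <= 15)
    Optional<int64_t> lb = cst.getConstantLowerBound(0);  // expected: 2
    Optional<int64_t> ub = cst.getConstantUpperBound(0);  // expected: 15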

View File

@@ -83,6 +83,8 @@ public:
return *this;
}
static AffineExpr Null() { return AffineExpr(nullptr); }
bool operator==(AffineExpr other) const { return expr == other.expr; }
bool operator!=(AffineExpr other) const { return !(*this == other); }
explicit operator bool() const { return expr; }
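
A small usage sketch for the comparison operators added above (illustrative only;
'context' is an MLIRContext*, and getAffineDimExpr is the existing factory used
elsewhere in this change):

    AffineExpr e = AffineExpr::Null();
    if (!e)  // explicit operator bool: false for a null wrapper
      e = getAffineDimExpr(/*position=*/0, context);
    assert(e != AffineExpr::Null());  // new operator!=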

View File

@@ -690,9 +690,9 @@ private:
// Builds a map from Value to identifier position in a new merged identifier
// list, which is the result of merging dim/symbol lists from src/dst
// iteration domains. The format of the new merged list is as follows:
// iteration domains, the format of which is as follows:
//
// [src-dim-identifiers, dst-dim-identifiers, symbol-identifiers]
// [src-dim-identifiers, dst-dim-identifiers, symbol-identifiers, const_term]
//
// This method populates 'valuePosMap' with mappings from operand Values in
// 'srcAccessMap'/'dstAccessMap' (as well as those in 'srcDomain'/'dstDomain')
@@ -700,22 +700,26 @@ private:
static void buildDimAndSymbolPositionMaps(
const FlatAffineConstraints &srcDomain,
const FlatAffineConstraints &dstDomain, const AffineValueMap &srcAccessMap,
const AffineValueMap &dstAccessMap, ValuePositionMap *valuePosMap) {
const AffineValueMap &dstAccessMap, ValuePositionMap *valuePosMap,
FlatAffineConstraints *dependenceConstraints) {
auto updateValuePosMap = [&](ArrayRef<Value *> values, bool isSrc) {
for (unsigned i = 0, e = values.size(); i < e; ++i) {
auto *value = values[i];
if (!isa<ForInst>(values[i]))
if (!isa<ForInst>(values[i])) {
assert(values[i]->isValidSymbol() &&
"access operand has to be either a loop IV or a symbol");
valuePosMap->addSymbolValue(value);
else if (isSrc)
} else if (isSrc) {
valuePosMap->addSrcValue(value);
else
} else {
valuePosMap->addDstValue(value);
}
}
};
SmallVector<Value *, 4> srcValues, destValues;
srcDomain.getIdValues(&srcValues);
dstDomain.getIdValues(&destValues);
srcDomain.getAllIdValues(&srcValues);
dstDomain.getAllIdValues(&destValues);
// Update value position map with identifiers from src iteration domain.
updateValuePosMap(srcValues, /*isSrc=*/true);
@@ -727,6 +731,65 @@ static void buildDimAndSymbolPositionMaps(
updateValuePosMap(dstAccessMap.getOperands(), /*isSrc=*/false);
}
// Sets up dependence constraints columns appropriately, in the format:
// [src-dim-identifiers, dst-dim-identifiers, symbol-identifiers, const_term]
void initDependenceConstraints(const FlatAffineConstraints &srcDomain,
const FlatAffineConstraints &dstDomain,
const AffineValueMap &srcAccessMap,
const AffineValueMap &dstAccessMap,
const ValuePositionMap &valuePosMap,
FlatAffineConstraints *dependenceConstraints) {
// Calculate number of equalities/inequalities and columns required to
// initialize FlatAffineConstraints for 'dependenceDomain'.
unsigned numIneq =
srcDomain.getNumInequalities() + dstDomain.getNumInequalities();
AffineMap srcMap = srcAccessMap.getAffineMap();
assert(srcMap.getNumResults() == dstAccessMap.getAffineMap().getNumResults());
unsigned numEq = srcMap.getNumResults();
unsigned numDims = srcDomain.getNumDimIds() + dstDomain.getNumDimIds();
unsigned numSymbols = valuePosMap.getNumSymbols();
unsigned numIds = numDims + numSymbols;
unsigned numCols = numIds + 1;
// Set flat affine constraints sizes and reserve space for constraints.
dependenceConstraints->reset(numIneq, numEq, numCols, numDims, numSymbols,
/*numLocals=*/0);
// Set values corresponding to dependence constraint identifiers.
SmallVector<Value *, 4> srcLoopIVs, dstLoopIVs;
srcDomain.getIdValues(0, srcDomain.getNumDimIds(), &srcLoopIVs);
dstDomain.getIdValues(0, dstDomain.getNumDimIds(), &dstLoopIVs);
dependenceConstraints->setIdValues(0, srcLoopIVs.size(), srcLoopIVs);
dependenceConstraints->setIdValues(
srcLoopIVs.size(), srcLoopIVs.size() + dstLoopIVs.size(), dstLoopIVs);
// Set values for the symbolic identifier dimensions.
auto setSymbolIds = [&](ArrayRef<Value *> values) {
for (auto *value : values) {
if (!isa<ForInst>(value)) {
assert(value->isValidSymbol() && "expected symbol");
dependenceConstraints->setIdValue(valuePosMap.getSymPos(value), value);
}
}
};
setSymbolIds(srcAccessMap.getOperands());
setSymbolIds(dstAccessMap.getOperands());
SmallVector<Value *, 8> srcSymbolValues, dstSymbolValues;
srcDomain.getIdValues(srcDomain.getNumDimIds(),
srcDomain.getNumDimAndSymbolIds(), &srcSymbolValues);
dstDomain.getIdValues(dstDomain.getNumDimIds(),
dstDomain.getNumDimAndSymbolIds(), &dstSymbolValues);
setSymbolIds(srcSymbolValues);
setSymbolIds(dstSymbolValues);
for (unsigned i = 0, e = dependenceConstraints->getNumDimAndSymbolIds();
i < e; i++)
assert(dependenceConstraints->getIds()[i].hasValue());
}
// Adds iteration domain constraints from 'srcDomain' and 'dstDomain' into
// 'dependenceDomain'.
// Uses 'valuePosMap' to determine the position in 'dependenceDomain' to which a
@@ -1278,25 +1341,15 @@ bool mlir::checkMemrefAccessDependence(
// Value to position in merged constraint system.
ValuePositionMap valuePosMap;
buildDimAndSymbolPositionMaps(srcDomain, dstDomain, srcAccessMap,
dstAccessMap, &valuePosMap);
dstAccessMap, &valuePosMap,
dependenceConstraints);
initDependenceConstraints(srcDomain, dstDomain, srcAccessMap, dstAccessMap,
valuePosMap, dependenceConstraints);
assert(valuePosMap.getNumDims() ==
srcDomain.getNumDimIds() + dstDomain.getNumDimIds());
// Calculate number of equalities/inequalities and columns required to
// initialize FlatAffineConstraints for 'dependenceDomain'.
unsigned numIneq =
srcDomain.getNumInequalities() + dstDomain.getNumInequalities();
AffineMap srcMap = srcAccessMap.getAffineMap();
assert(srcMap.getNumResults() == dstAccessMap.getAffineMap().getNumResults());
unsigned numEq = srcMap.getNumResults();
unsigned numDims = srcDomain.getNumDimIds() + dstDomain.getNumDimIds();
unsigned numSymbols = valuePosMap.getNumSymbols();
unsigned numIds = numDims + numSymbols;
unsigned numCols = numIds + 1;
// Create flat affine constraints reserving space for 'numEq' and 'numIneq'.
dependenceConstraints->reset(numIneq, numEq, numCols, numDims, numSymbols,
/*numLocals=*/0);
// Create memref access constraint by equating src/dst access functions.
// Note that this check is conservative, and will fail in the future
// when local variables for mod/div exprs are supported.
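
To make the merged column format above concrete, a hypothetical example (not taken
from this commit): with a 2-d source domain, a 2-d destination domain, and one
symbol, initDependenceConstraints sets up the columns

    [src_d0, src_d1, dst_d0, dst_d1, sym0, const]

i.e. numDims = 2 + 2 = 4, numSymbols = 1, and numCols = numIds + 1 = 6; the src/dst
loop IV Values are attached to the first four columns and the symbol's Value to
column 4 via setIdValue.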

View File

@@ -1137,50 +1137,268 @@ unsigned FlatAffineConstraints::gaussianEliminateIds(unsigned posStart,
return posLimit - posStart;
}
AffineMap FlatAffineConstraints::toAffineMapFromEq(
unsigned pos, MLIRContext *context,
SmallVectorImpl<unsigned> *nonZeroDimIds,
SmallVectorImpl<unsigned> *nonZeroSymbolIds) {
// Detect the identifier at 'pos' (say id_r) as modulo of another identifier
// (say id_n) w.r.t a constant. When this happens, another identifier (say id_q)
// could be detected as the floordiv of n. For eg:
// id_n - 4*id_q - id_r = 0, 0 <= id_r <= 3 <=>
// id_r = id_n mod 4, id_q = id_n floordiv 4.
// lbConst and ubConst are the constant lower and upper bounds for 'pos' -
// pre-detected at the caller.
static bool detectAsMod(const FlatAffineConstraints &cst, unsigned pos,
int64_t lbConst, int64_t ubConst,
SmallVectorImpl<AffineExpr> *memo) {
assert(pos < cst.getNumIds() && "invalid position");
// For now just project out local IDs, and return null if we can't
// find an equality. TODO(bondhugula): infer as a function of other
// dims/symbols involving mod/div.
projectOut(getNumIds() - getNumLocalIds(), getNumLocalIds());
// Check if 0 <= id_r <= divisor - 1 and if id_r is equal to
// id_n - divisor * id_q. If these are true, then id_n becomes the dividend
// and id_q the quotient when dividing id_n by the divisor.
unsigned idx;
if (!findConstraintWithNonZeroAt(*this, pos, /*isEq=*/true, &idx))
return AffineMap::Null();
if (lbConst != 0 || ubConst < 1)
return false;
// Build AffineExpr solving for identifier 'pos' in terms of all others.
auto expr = getAffineConstantExpr(0, context);
unsigned mapNumDims = 0;
unsigned mapNumSymbols = 0;
for (unsigned j = 0, e = getNumIds(); j < e; ++j) {
if (j == pos)
int64_t divisor = ubConst + 1;
// Now check for: id_r = id_n - divisor * id_q. As an example, we are
// looking for r = d - 4q, i.e., either r - d + 4q = 0 or -r + d - 4q = 0.
unsigned seenQuotient = 0, seenDividend = 0;
int quotientPos = -1, dividendPos = -1;
for (unsigned r = 0, e = cst.getNumEqualities(); r < e; r++) {
// id_n should have coeff 1 or -1.
if (std::abs(cst.atEq(r, pos)) != 1)
continue;
int64_t c = atEq(idx, j);
if (c == 0)
continue;
if (j < numDims) {
expr = expr - getAffineDimExpr(mapNumDims++, context) * c;
nonZeroDimIds->push_back(j);
} else {
expr =
expr - getAffineSymbolExpr(mapNumDims + mapNumSymbols++, context) * c;
nonZeroSymbolIds->push_back(j);
for (unsigned c = 0, f = cst.getNumDimAndSymbolIds(); c < f; c++) {
// The coeff of the quotient should be -divisor if the coefficient of
// the pos^th identifier is -1, and divisor if the latter is 1.
if (cst.atEq(r, c) * cst.atEq(r, pos) == divisor) {
seenQuotient++;
quotientPos = c;
} else if (cst.atEq(r, c) * cst.atEq(r, pos) == -1) {
seenDividend++;
dividendPos = c;
}
}
// We are looking for exactly one identifier as part of the dividend.
// TODO(bondhugula): could be extended to cover multiple ones in the
// dividend to detect mod of an affine function of identifiers.
if (seenDividend == 1 && seenQuotient >= 1) {
if (!(*memo)[dividendPos])
return false;
// Successfully detected a mod.
(*memo)[pos] = (*memo)[dividendPos] % divisor;
if (seenQuotient == 1 && !(*memo)[quotientPos])
// Successfully detected a floordiv as well.
(*memo)[quotientPos] = (*memo)[dividendPos].floorDiv(divisor);
return true;
}
}
// Add constant term to AffineExpr.
expr = expr - atEq(idx, getNumIds());
int64_t v = atEq(idx, pos);
assert(v != 0 && "expected non-zero here");
if (v > 0)
expr = expr.floorDiv(v);
else
// v < 0.
expr = (-expr).floorDiv(-v);
return false;
}
return AffineMap::get(mapNumDims, mapNumSymbols, {expr}, {});
// Check if the pos^th identifier can be expressed as a floordiv of an affine
// function of other identifiers (where the divisor is a positive constant).
// For eg: 4q <= i + j <= 4q + 3 <=> q = (i + j) floordiv 4.
bool detectAsFloorDiv(const FlatAffineConstraints &cst, unsigned pos,
SmallVectorImpl<AffineExpr> *memo, MLIRContext *context) {
assert(pos < cst.getNumIds() && "invalid position");
SmallVector<unsigned, 4> lbIndices, ubIndices;
// Gather all lower bounds and upper bound constraints of this identifier.
// Since the canonical form c_1*x_1 + c_2*x_2 + ... + c_0 >= 0, a constraint
// is a lower bound for x_i if c_i >= 1, and an upper bound if c_i <= -1.
for (unsigned r = 0, e = cst.getNumInequalities(); r < e; r++) {
if (cst.atIneq(r, pos) >= 1)
// Lower bound.
lbIndices.push_back(r);
else if (cst.atIneq(r, pos) <= -1)
// Upper bound.
ubIndices.push_back(r);
}
// Check if any lower bound, upper bound pair is of the form:
// divisor * id >= expr - (divisor - 1) <-- Lower bound for 'id'
// divisor * id <= expr <-- Upper bound for 'id'
// Then, 'id' is equivalent to 'expr floordiv divisor'. (where divisor > 1).
//
// For example, if -32*k + 16*i + j >= 0
// 32*k - 16*i - j + 31 >= 0 <=>
// k = ( 16*i + j ) floordiv 32
unsigned seenDividends = 0;
for (auto ubPos : ubIndices) {
for (auto lbPos : lbIndices) {
// Check if lower bound's constant term is 'divisor - 1'. The 'divisor'
// here is cst.atIneq(lbPos, pos) and we already know that it's positive
// (since cst.atIneq(lbPos, ...) is a lower bound expression for 'pos').
if (cst.atIneq(lbPos, cst.getNumCols() - 1) != cst.atIneq(lbPos, pos) - 1)
continue;
// Check if upper bound's constant term is 0.
if (cst.atIneq(ubPos, cst.getNumCols() - 1) != 0)
continue;
// For the remaining part, check if the lower bound expr's coeff's are
// negations of corresponding upper bound ones'.
unsigned c, f;
for (c = 0, f = cst.getNumCols() - 1; c < f; c++) {
if (cst.atIneq(lbPos, c) != -cst.atIneq(ubPos, c))
break;
if (c != pos && cst.atIneq(lbPos, c) != 0)
seenDividends++;
}
// Lb coeff's aren't negations of ub coeff's (for the non-constant term
// part).
if (c < f)
continue;
if (seenDividends >= 1) {
// The divisor is the constant term of the lower bound expression.
// We already know that cst.atIneq(lbPos, pos) > 0.
int64_t divisor = cst.atIneq(lbPos, pos);
// Construct the dividend expression.
auto dividendExpr = getAffineConstantExpr(0, context);
unsigned c, f;
for (c = 0, f = cst.getNumCols() - 1; c < f; c++) {
if (c == pos)
continue;
int64_t ubVal = cst.atIneq(ubPos, c);
if (ubVal == 0)
continue;
if (!(*memo)[c])
break;
dividendExpr = dividendExpr + ubVal * (*memo)[c];
}
// Expression can't be constructed as it depends on a yet unknown
// identifier.
// TODO(mlir-team): Visit/compute the identifiers in an order so that
// this doesn't happen. More complex but much more efficient.
if (c < f)
continue;
// Successfully detected the floordiv.
(*memo)[pos] = dividendExpr.floorDiv(divisor);
return true;
}
}
}
return false;
}
/// Computes the lower and upper bounds of the first 'num' dimensional
/// identifiers as affine maps of the remaining identifiers (dimensional and
/// symbolic identifiers). Local identifiers are themselves explicitly computed
/// as affine functions of other identifiers in this process if needed.
void FlatAffineConstraints::getSliceBounds(unsigned num, MLIRContext *context,
SmallVectorImpl<AffineMap> *lbMaps,
SmallVectorImpl<AffineMap> *ubMaps) {
assert(num < getNumDimIds() && "invalid range");
// Basic simplification.
normalizeConstraintsByGCD();
LLVM_DEBUG(llvm::dbgs() << "getSliceBounds on:\n");
LLVM_DEBUG(dump());
// Record computed/detected identifiers.
SmallVector<AffineExpr, 8> memo(getNumIds(), AffineExpr::Null());
// Initialize dimensional and symbolic identifiers.
for (unsigned i = num, e = getNumDimIds(); i < e; i++)
memo[i] = getAffineDimExpr(i - num, context);
for (unsigned i = getNumDimIds(), e = getNumDimAndSymbolIds(); i < e; i++)
memo[i] = getAffineSymbolExpr(i - getNumDimIds(), context);
bool changed;
do {
changed = false;
// Identify yet unknown identifiers as constants or mod's / floordiv's of
// other identifiers if possible.
for (unsigned pos = 0; pos < getNumIds(); pos++) {
if (memo[pos])
continue;
auto lbConst = getConstantLowerBound(pos);
auto ubConst = getConstantUpperBound(pos);
if (lbConst.hasValue() && ubConst.hasValue()) {
// Detect equality to a constant.
if (lbConst.getValue() == ubConst.getValue()) {
memo[pos] = getAffineConstantExpr(lbConst.getValue(), context);
changed = true;
continue;
}
// Detect an identifier as modulo of another identifier w.r.t a
// constant.
if (detectAsMod(*this, pos, lbConst.getValue(), ubConst.getValue(),
&memo)) {
changed = true;
continue;
}
}
// Detect an identifier as floordiv of another identifier w.r.t a
// constant.
if (detectAsFloorDiv(*this, pos, &memo, context)) {
changed = true;
continue;
}
// Detect an identifier as an expression of other identifiers.
unsigned idx;
if (!findConstraintWithNonZeroAt(*this, pos, /*isEq=*/true, &idx)) {
continue;
}
// Build AffineExpr solving for identifier 'pos' in terms of all others.
auto expr = getAffineConstantExpr(0, context);
unsigned j, e;
for (j = 0, e = getNumIds(); j < e; ++j) {
if (j == pos)
continue;
int64_t c = atEq(idx, j);
if (c == 0)
continue;
// If any of the involved IDs hasn't been found yet, we can't proceed.
if (!memo[j])
break;
expr = expr + memo[j] * c;
}
if (j < e)
// Can't construct expression as it depends on a yet uncomputed
// identifier.
continue;
// Add constant term to AffineExpr.
expr = expr + atEq(idx, getNumIds());
int64_t vPos = atEq(idx, pos);
assert(vPos != 0 && "expected non-zero here");
if (vPos > 0)
expr = (-expr).floorDiv(vPos);
else
// vPos < 0.
expr = expr.floorDiv(-vPos);
// Successfully constructed expression.
memo[pos] = expr;
changed = true;
}
// This loop is guaranteed to reach a fixed point - since once an
// identifier's explicit form is computed (in memo[pos]), it's not updated
// again.
} while (changed);
// Set the lower and upper bound maps for all the identifiers that were
// computed as affine expressions of the rest as the "detected expr" and
// "detected expr + 1" respectively; set the undetected ones to Null().
for (unsigned pos = 0; pos < num; pos++) {
unsigned numMapDims = getNumDimIds() - num;
unsigned numMapSymbols = getNumSymbolIds();
AffineExpr expr = memo[pos];
if (expr)
expr = simplifyAffineExpr(expr, numMapDims, numMapSymbols);
if (expr) {
(*lbMaps)[pos] = AffineMap::get(numMapDims, numMapSymbols, expr, {});
(*ubMaps)[pos] = AffineMap::get(numMapDims, numMapSymbols, expr + 1, {});
} else {
(*lbMaps)[pos] = AffineMap::Null();
(*ubMaps)[pos] = AffineMap::Null();
}
LLVM_DEBUG(llvm::dbgs() << "lb map for pos = " << Twine(pos) << ", expr: ");
LLVM_DEBUG(expr.dump(););
}
}
void FlatAffineConstraints::addEquality(ArrayRef<int64_t> eq) {
@@ -1456,7 +1674,7 @@ bool FlatAffineConstraints::constantFoldId(unsigned pos) {
// atEq(rowIdx, pos) is either -1 or 1.
assert(atEq(rowIdx, pos) * atEq(rowIdx, pos) == 1);
int64_t constVal = atEq(rowIdx, getNumCols() - 1) / -atEq(rowIdx, pos);
int64_t constVal = -atEq(rowIdx, getNumCols() - 1) / atEq(rowIdx, pos);
setAndEliminate(pos, constVal);
return true;
}
@@ -1513,19 +1731,24 @@ Optional<int64_t> FlatAffineConstraints::getConstantBoundOnDimSize(
if (atIneq(r, pos) != 0)
break;
}
if (r == e) {
// If it doesn't appear, just remove the column and return.
// TODO(andydavis,bondhugula): refactor removeColumns to use it from here.
if (r == e)
// If it doesn't, there isn't a bound on it.
return None;
}
// Positions of constraints that are lower/upper bounds on the variable.
SmallVector<unsigned, 4> lbIndices, ubIndices;
// Gather all lower bounds and upper bounds of the variable. Since the
// canonical form c_1*x_1 + c_2*x_2 + ... + c_0 >= 0, a constraint is a lower
// bound for x_i if c_i >= 1, and an upper bound if c_i <= -1.
// Gather all symbolic lower bounds and upper bounds of the variable. Since
// the canonical form c_1*x_1 + c_2*x_2 + ... + c_0 >= 0, a constraint is a
// lower bound for x_i if c_i >= 1, and an upper bound if c_i <= -1.
for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
unsigned c, f;
for (c = 0, f = getNumDimIds(); c < f; c++) {
if (c != pos && atIneq(r, c) != 0)
break;
}
if (c < getNumDimIds())
continue;
if (atIneq(r, pos) >= 1)
// Lower bound.
lbIndices.push_back(r);
@@ -1554,10 +1777,10 @@ Optional<int64_t> FlatAffineConstraints::getConstantBoundOnDimSize(
}
if (j < getNumCols() - 1)
continue;
int64_t mayDiff =
int64_t diff =
atIneq(ubPos, getNumCols() - 1) + atIneq(lbPos, getNumCols() - 1) + 1;
if (minDiff == None || mayDiff < minDiff) {
minDiff = mayDiff;
if (minDiff == None || diff < minDiff) {
minDiff = diff;
minLbPosition = lbPos;
}
}
@@ -1572,6 +1795,71 @@ Optional<int64_t> FlatAffineConstraints::getConstantBoundOnDimSize(
return minDiff;
}
template <bool isLower>
Optional<int64_t>
FlatAffineConstraints::getConstantLowerOrUpperBound(unsigned pos) const {
// Check if there's an equality equating the 'pos'^th identifier to a
// constant.
int eqRowIdx = findEqualityToConstant(*this, pos, /*symbolic=*/false);
if (eqRowIdx != -1)
// atEq(rowIdx, pos) is either -1 or 1.
return -atEq(eqRowIdx, getNumCols() - 1) / atEq(eqRowIdx, pos);
// Check if the identifier appears at all in any of the inequalities.
unsigned r, e;
for (r = 0, e = getNumInequalities(); r < e; r++) {
if (atIneq(r, pos) != 0)
break;
}
if (r == e)
// If it doesn't, there isn't a bound on it.
return None;
Optional<int64_t> minOrMaxConst = None;
// Take the max across all const lower bounds (or min across all constant
// upper bounds).
for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
if (isLower) {
if (atIneq(r, pos) <= 0)
// Not a lower bound.
continue;
} else if (atIneq(r, pos) >= 0) {
// Not an upper bound.
continue;
}
unsigned c, f;
for (c = 0, f = getNumCols() - 1; c < f; c++)
if (c != pos && atIneq(r, c) != 0)
break;
if (c < getNumCols() - 1)
// Not a constant bound.
continue;
int64_t boundConst =
isLower ? mlir::ceilDiv(-atIneq(r, getNumCols() - 1), atIneq(r, pos))
: mlir::floorDiv(atIneq(r, getNumCols() - 1), -atIneq(r, pos));
if (isLower) {
if (minOrMaxConst == None || boundConst > minOrMaxConst)
minOrMaxConst = boundConst;
} else {
if (minOrMaxConst == None || boundConst < minOrMaxConst)
minOrMaxConst = boundConst;
}
}
return minOrMaxConst;
}
Optional<int64_t>
FlatAffineConstraints::getConstantLowerBound(unsigned pos) const {
return getConstantLowerOrUpperBound</*isLower=*/true>(pos);
}
Optional<int64_t>
FlatAffineConstraints::getConstantUpperBound(unsigned pos) const {
return getConstantLowerOrUpperBound</*isLower=*/false>(pos);
}
// A simple (naive and conservative) check for hyper-rectangularity.
bool FlatAffineConstraints::isHyperRectangular(unsigned pos,
unsigned num) const {
@@ -1912,7 +2200,7 @@ void FlatAffineConstraints::projectOut(unsigned pos, unsigned num) {
return;
// 'pos' can be at most getNumCols() - 2 if num > 0.
assert(pos <= getNumCols() - 2 && "invalid position");
assert(getNumCols() < 2 || pos <= getNumCols() - 2 && "invalid position");
assert(pos + num < getNumCols() && "invalid range");
// Eliminate as many identifiers as possible using Gaussian elimination.
@@ -1930,8 +2218,9 @@ void FlatAffineConstraints::projectOut(unsigned pos, unsigned num) {
// Eliminate the remaining using Fourier-Motzkin.
for (unsigned i = 0; i < num - numGaussianEliminated; i++) {
unsigned elimId = getBestIdToEliminate(*this, pos, getNumIds());
FourierMotzkinEliminate(elimId);
unsigned numToEliminate = num - numGaussianEliminated - i;
FourierMotzkinEliminate(
getBestIdToEliminate(*this, pos, pos + numToEliminate));
}
// Fast/trivial simplifications.
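
An end-to-end sketch of the new mod/floordiv detection driven through
getSliceBounds (a fragment, not code from this commit; it assumes an empty
FlatAffineConstraints 'cst' and an MLIRContext* 'context'). Columns are
[d0, d1, d2, const], and the system encodes d1 - 4*d2 - d0 = 0 with 0 <= d0 <= 3:

    cst.reset(/*numIneq=*/2, /*numEq=*/1, /*numCols=*/4, /*numDims=*/3,
              /*numSymbols=*/0, /*numLocals=*/0);
    cst.addEquality({-1, 1, -4, 0});   // -d0 + d1 - 4*d2 + 0 = 0
    cst.addInequality({1, 0, 0, 0});   //  d0 >= 0
    cst.addInequality({-1, 0, 0, 3});  // -d0 + 3 >= 0, i.e. d0 <= 3
    SmallVector<AffineMap, 4> lbs(1, AffineMap::Null());
    SmallVector<AffineMap, 4> ubs(1, AffineMap::Null());
    // detectAsMod should find d0 = d1 mod 4 (and d2 = d1 floordiv 4), so the
    // bound maps for d0 over the remaining ids (d1, d2) are expected to be
    // 'd1 mod 4' and 'd1 mod 4 + 1'.
    cst.getSliceBounds(/*num=*/1, context, &lbs, &ubs);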

View File

@@ -340,6 +340,14 @@ static Instruction *getInstAtPosition(ArrayRef<unsigned> positions,
// dependence constraint system to create AffineMaps with which to adjust the
// loop bounds of the inserted computation slice so that they are functions of the
// loop IVs and symbols of the loops surrounding 'dstAccess'.
// TODO(andydavis,bondhugula): extend the slicing utility to compute slices that
// aren't necessarily a one-to-one relation b/w the source and destination. The
// relation between the source and destination could be many-to-many in general.
// TODO(andydavis,bondhugula): the slice computation is incorrect in the cases
// where the dependence from the source to the destination does not cover the
// entire destination index set. Subtract out the dependent destination
// iterations from destination index set and check for emptiness --- this is one
// solution.
ForInst *mlir::insertBackwardComputationSlice(MemRefAccess *srcAccess,
MemRefAccess *dstAccess,
unsigned srcLoopDepth,
@@ -351,89 +359,74 @@ ForInst *mlir::insertBackwardComputationSlice(MemRefAccess *srcAccess,
return nullptr;
}
// Get loop nest surrounding src operation.
SmallVector<ForInst *, 4> srcLoopNest;
getLoopIVs(*srcAccess->opInst, &srcLoopNest);
unsigned srcLoopNestSize = srcLoopNest.size();
assert(srcLoopDepth <= srcLoopNestSize);
// Get loop nest surrounding dst operation.
SmallVector<ForInst *, 4> dstLoopNest;
getLoopIVs(*dstAccess->opInst, &dstLoopNest);
unsigned dstLoopNestSize = dstLoopNest.size();
(void)dstLoopNestSize;
assert(dstLoopDepth > 0);
assert(dstLoopDepth <= dstLoopNestSize);
// Solve for src IVs in terms of dst IVs, symbols and constants.
SmallVector<AffineMap, 4> srcIvMaps(srcLoopNestSize, AffineMap::Null());
std::vector<SmallVector<Value *, 2>> srcIvOperands(srcLoopNestSize);
for (unsigned i = 0; i < srcLoopNestSize; ++i) {
// Skip IVs which are greater than requested loop depth.
if (i >= srcLoopDepth) {
srcIvMaps[i] = AffineMap::Null();
continue;
}
auto cst = dependenceConstraints.clone();
for (int j = srcLoopNestSize - 1; j >= 0; --j) {
if (i != j)
cst->projectOut(j);
}
SmallVector<unsigned, 2> nonZeroDimIds;
SmallVector<unsigned, 2> nonZeroSymbolIds;
srcIvMaps[i] = cst->toAffineMapFromEq(0, srcAccess->opInst->getContext(),
&nonZeroDimIds, &nonZeroSymbolIds);
// Add operands for all non-zero dst dims and symbols.
// TODO(andydavis) Add local variable support.
for (auto dimId : nonZeroDimIds) {
if (dimId - 1 >= dstLoopDepth) {
// This src IV has a dependence on dst IV dstLoopDepth where it will
// be inserted. So we cannot slice the iteration space at srcLoopDepth,
// and also insert it into the dst loop nest at 'dstLoopDepth'.
return nullptr;
}
srcIvOperands[i].push_back(dstLoopNest[dimId - 1]);
}
// TODO(andydavis) Add symbols from the access function. Ideally, we
// should be able to query the constraint system for the Value associated
// with a symbol identifiers in 'nonZeroSymbolIds'.
SmallVector<ForInst *, 4> srcLoopIVs;
getLoopIVs(*srcAccess->opInst, &srcLoopIVs);
unsigned numSrcLoopIVs = srcLoopIVs.size();
if (srcLoopDepth > numSrcLoopIVs) {
srcAccess->opInst->emitError("invalid source loop depth");
return nullptr;
}
// Find the inst block positions of 'srcAccess->opInst' within 'srcLoopNest'.
// Get loop nest surrounding dst operation.
SmallVector<ForInst *, 4> dstLoopIVs;
getLoopIVs(*dstAccess->opInst, &dstLoopIVs);
unsigned dstLoopIVsSize = dstLoopIVs.size();
if (dstLoopDepth > dstLoopIVsSize) {
dstAccess->opInst->emitError("invalid destination loop depth");
return nullptr;
}
// Project out dimensions other than those up to src/dstLoopDepth's.
dependenceConstraints.projectOut(srcLoopDepth, numSrcLoopIVs - srcLoopDepth);
dependenceConstraints.projectOut(srcLoopDepth + dstLoopDepth,
dstLoopIVsSize - dstLoopDepth);
// Set up lower/upper bound affine maps for the slice.
SmallVector<AffineMap, 4> sliceLbs(srcLoopDepth, AffineMap::Null());
SmallVector<AffineMap, 4> sliceUbs(srcLoopDepth, AffineMap::Null());
// Get bounds for src IVs in terms of dst IVs, symbols, and constants.
dependenceConstraints.getSliceBounds(std::min(srcLoopDepth, numSrcLoopIVs),
srcAccess->opInst->getContext(),
&sliceLbs, &sliceUbs);
// Set up bound operands for the slice's lower and upper bounds.
SmallVector<Value *, 4> sliceBoundOperands;
dependenceConstraints.getIdValues(
srcLoopDepth, dependenceConstraints.getNumDimAndSymbolIds(),
&sliceBoundOperands);
// Find the inst block positions of 'srcAccess->opInst' within 'srcLoopIVs'.
SmallVector<unsigned, 4> positions;
findInstPosition(srcAccess->opInst, srcLoopNest[0]->getBlock(), &positions);
// TODO(andydavis): This code is incorrect since srcLoopIVs can be 0-d.
findInstPosition(srcAccess->opInst, srcLoopIVs[0]->getBlock(), &positions);
// Clone src loop nest and insert it at the beginning of the instruction block
// of the loop at 'dstLoopDepth' in 'dstLoopNest'.
auto *dstForInst = dstLoopNest[dstLoopDepth - 1];
// of the loop at 'dstLoopDepth' in 'dstLoopIVs'.
auto *dstForInst = dstLoopIVs[dstLoopDepth - 1];
FuncBuilder b(dstForInst->getBody(), dstForInst->getBody()->begin());
DenseMap<const Value *, Value *> operandMap;
auto *sliceLoopNest = cast<ForInst>(b.clone(*srcLoopNest[0], operandMap));
auto *sliceLoopNest = cast<ForInst>(b.clone(*srcLoopIVs[0], operandMap));
// Lookup inst in cloned 'sliceLoopNest' at 'positions'.
Instruction *sliceInst =
getInstAtPosition(positions, /*level=*/0, sliceLoopNest->getBody());
// Get loop nest surrounding 'sliceInst'.
SmallVector<ForInst *, 4> sliceSurroundingLoops;
getLoopIVs(*sliceInst, &sliceSurroundingLoops);
// Sanity check.
unsigned sliceSurroundingLoopsSize = sliceSurroundingLoops.size();
(void)sliceSurroundingLoopsSize;
unsigned sliceLoopLimit = dstLoopDepth + numSrcLoopIVs;
assert(sliceLoopLimit >= sliceSurroundingLoopsSize);
// Update loop bounds for loops in 'sliceLoopNest'.
unsigned sliceLoopLimit = dstLoopDepth + srcLoopNestSize;
assert(sliceLoopLimit <= sliceSurroundingLoopsSize);
for (unsigned i = dstLoopDepth; i < sliceLoopLimit; ++i) {
auto *forInst = sliceSurroundingLoops[i];
unsigned index = i - dstLoopDepth;
AffineMap lbMap = srcIvMaps[index];
if (lbMap == AffineMap::Null())
continue;
forInst->setLowerBound(srcIvOperands[index], lbMap);
// Create upper bound map which is lower bound map + 1.
assert(lbMap.getNumResults() == 1);
AffineExpr ubResultExpr = lbMap.getResult(0) + 1;
AffineMap ubMap = AffineMap::get(lbMap.getNumDims(), lbMap.getNumSymbols(),
{ubResultExpr}, {});
forInst->setUpperBound(srcIvOperands[index], ubMap);
for (unsigned i = 0; i < srcLoopDepth; ++i) {
auto *forInst = sliceSurroundingLoops[dstLoopDepth + i];
if (AffineMap lbMap = sliceLbs[i])
forInst->setLowerBound(sliceBoundOperands, lbMap);
if (AffineMap ubMap = sliceUbs[i])
forInst->setUpperBound(sliceBoundOperands, ubMap);
}
return sliceLoopNest;
}

View File

@@ -1409,6 +1409,10 @@ void IntegerSet::dump() const {
}
void AffineExpr::print(raw_ostream &os) const {
if (expr == nullptr) {
os << "null affine expr";
return;
}
ModuleState state(getContext());
ModulePrinter(os, state).printAffineExpr(*this);
}
@@ -1419,6 +1423,10 @@ void AffineExpr::dump() const {
}
void AffineMap::print(raw_ostream &os) const {
if (map == nullptr) {
os << "null affine map";
return;
}
ModuleState state(getContext());
ModulePrinter(os, state).printAffineMap(*this);
}
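
With the guards above, printing null wrappers becomes safe; a quick illustration
(a fragment, assuming the usual llvm::errs() stream is available):

    AffineExpr::Null().print(llvm::errs());  // prints "null affine expr"
    AffineMap::Null().print(llvm::errs());   // prints "null affine map"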

View File

@@ -47,12 +47,12 @@ using namespace mlir;
// depth per-loop nest, or depth per load/store op) for this pass utilizing a
// cost model.
static llvm::cl::opt<unsigned> clSrcLoopDepth(
"src-loop-depth", llvm::cl::Hidden,
"fusion-src-loop-depth", llvm::cl::Hidden,
llvm::cl::desc("Controls the depth of the source loop nest at which "
"to apply loop iteration slicing before fusion."));
static llvm::cl::opt<unsigned> clDstLoopDepth(
"dst-loop-depth", llvm::cl::Hidden,
"fusion-dst-loop-depth", llvm::cl::Hidden,
llvm::cl::desc("Controls the depth of the destination loop nest at which "
"to fuse the source loop nest slice."));

View File

@@ -117,6 +117,8 @@ void MemRefDataFlowOpt::visitOperationInst(OperationInst *opInst) {
storeOps.push_back(storeOpInst);
}
unsigned loadOpDepth = getNestingDepth(*loadOpInst);
// 1. Check if there is a dependence satisfied at depth equal to the depth
// of the loop body of the innermost common surrounding loop of the storeOp
// and loadOp.
@@ -165,11 +167,16 @@ void MemRefDataFlowOpt::visitOperationInst(OperationInst *opInst) {
// a [i] = ...
// for (j ...)
// ... = a[j]
MemRefRegion region;
getMemRefRegion(loadOpInst, nsLoops, &region);
if (!region.getConstraints()->isRangeOneToOne(
/*start=*/0, /*limit=*/loadOp->getMemRefType().getRank()))
break;
// If storeOpInst and loadOpInst are at the same nesting depth, the load op
// is trivially loading from a single location at that depth; so there
// isn't a need to call isRangeOneToOne.
if (getNestingDepth(*storeOpInst) < loadOpDepth) {
MemRefRegion region;
getMemRefRegion(loadOpInst, nsLoops, &region);
if (!region.getConstraints()->isRangeOneToOne(
/*start=*/0, /*limit=*/loadOp->getMemRefType().getRank()))
break;
}
// After all these conditions, we have a candidate for forwarding!
fwdingCandidates.push_back(storeOpInst);

View File

@@ -1,5 +1,5 @@
// RUN: mlir-opt %s -loop-fusion -split-input-file -verify | FileCheck %s
// RUN: mlir-opt %s -loop-fusion -src-loop-depth=1 -dst-loop-depth=1 -split-input-file -verify | FileCheck %s --check-prefix DEPTH1
// RUN: mlir-opt %s -loop-fusion -fusion-src-loop-depth=1 -fusion-dst-loop-depth=1 -split-input-file -verify | FileCheck %s --check-prefix DEPTH1
// TODO(andydavis) Add more tests:
// *) Add nested fusion test cases when non-constant loop bound support is
@@ -76,7 +76,8 @@ func @should_fuse_reduction_to_pointwise() {
// -----
// CHECK: [[MAP_SHIFT_MINUS_ONE:#map[0-9]+]] = (d0) -> (d0 - 1)
// CHECK: [[MAP_SHIFT_MINUS_ONE_D0:#map[0-9]+]] = (d0, d1) -> (d0 - 1)
// CHECK: [[MAP_SHIFT_MINUS_ONE_D1:#map[0-9]+]] = (d0, d1) -> (d1 - 1)
// CHECK: [[MAP_SHIFT_BY_ONE:#map[0-9]+]] = (d0, d1) -> (d0 + 1, d1 + 1)
// CHECK-LABEL: func @should_fuse_loop_nests_with_shifts() {
@@ -98,8 +99,8 @@ func @should_fuse_loop_nests_with_shifts() {
// CHECK: for %i0 = 0 to 10 {
// CHECK-NEXT: for %i1 = 0 to 10 {
// CHECK-NEXT: %1 = affine_apply [[MAP_SHIFT_MINUS_ONE]](%i0)
// CHECK-NEXT: %2 = affine_apply [[MAP_SHIFT_MINUS_ONE]](%i1)
// CHECK-NEXT: %1 = affine_apply [[MAP_SHIFT_MINUS_ONE_D0]](%i0, %i1)
// CHECK-NEXT: %2 = affine_apply [[MAP_SHIFT_MINUS_ONE_D1]](%i0, %i1)
// CHECK-NEXT: %3 = affine_apply [[MAP_SHIFT_BY_ONE]](%1, %2)
// CHECK-NEXT: store %cst, %0[%3#0, %3#1] : memref<10x10xf32>
// CHECK-NEXT: %4 = load %0[%i0, %i1] : memref<10x10xf32>
@@ -111,7 +112,8 @@ func @should_fuse_loop_nests_with_shifts() {
// -----
// CHECK: [[MAP_IDENTITY:#map[0-9]+]] = (d0) -> (d0)
// CHECK-DAG: [[MAP_DIM0:#map[0-9]+]] = (d0, d1) -> (d0)
// CHECK-DAG: [[MAP_DIM1:#map[0-9]+]] = (d0, d1) -> (d1)
// CHECK-LABEL: func @should_fuse_loop_nest() {
func @should_fuse_loop_nest() {
@@ -138,11 +140,11 @@ func @should_fuse_loop_nest() {
// CHECK: for %i0 = 0 to 10 {
// CHECK-NEXT: for %i1 = 0 to 10 {
// CHECK-NEXT: %2 = affine_apply [[MAP_IDENTITY]](%i1)
// CHECK-NEXT: %3 = affine_apply [[MAP_IDENTITY]](%i0)
// CHECK-NEXT: %2 = affine_apply [[MAP_DIM1]](%i0, %i1)
// CHECK-NEXT: %3 = affine_apply [[MAP_DIM0]](%i0, %i1)
// CHECK-NEXT: store %cst, %0[%2, %3] : memref<10x10xf32>
// CHECK-NEXT: %4 = affine_apply [[MAP_IDENTITY]](%i0)
// CHECK-NEXT: %5 = affine_apply [[MAP_IDENTITY]](%i1)
// CHECK-NEXT: %4 = affine_apply [[MAP_DIM0]](%i0, %i1)
// CHECK-NEXT: %5 = affine_apply [[MAP_DIM1]](%i0, %i1)
// CHECK-NEXT: %6 = load %0[%5, %4] : memref<10x10xf32>
// CHECK-NEXT: store %6, %1[%4, %5] : memref<10x10xf32>
// CHECK-NEXT: %7 = load %1[%i0, %i1] : memref<10x10xf32>
@@ -509,9 +511,11 @@ func @should_not_fuse_if_inst_in_loop_nest() {
// -----
// CHECK: [[MAP0:#map[0-9]+]] = (d0) -> (d0)
// CHECK: [[MAP1:#map[0-9]+]] = (d0, d1, d2) -> (d0, d1, d2)
// CHECK: [[MAP2:#map[0-9]+]] = (d0, d1, d2) -> (d1, d2, d0)
// CHECK-DAG: [[MAP_D0:#map[0-9]+]] = (d0, d1, d2) -> (d0)
// CHECK-DAG: [[MAP_D1:#map[0-9]+]] = (d0, d1, d2) -> (d1)
// CHECK-DAG: [[MAP_D2:#map[0-9]+]] = (d0, d1, d2) -> (d2)
// CHECK: [[MAP_IDENTITY:#map[0-9]+]] = (d0, d1, d2) -> (d0, d1, d2)
// CHECK: [[MAP_PERMUTE:#map[0-9]+]] = (d0, d1, d2) -> (d1, d2, d0)
// CHECK-LABEL: func @remap_ivs() {
func @remap_ivs() {
@@ -537,12 +541,12 @@ func @remap_ivs() {
// CHECK: for %i0 = 0 to 30 {
// CHECK-NEXT: for %i1 = 0 to 10 {
// CHECK-NEXT: for %i2 = 0 to 20 {
// CHECK-NEXT: %1 = affine_apply [[MAP0]](%i1)
// CHECK-NEXT: %2 = affine_apply [[MAP0]](%i2)
// CHECK-NEXT: %3 = affine_apply [[MAP0]](%i0)
// CHECK-NEXT: %4 = affine_apply [[MAP1]](%1, %2, %3)
// CHECK-NEXT: %1 = affine_apply [[MAP_D1]](%i0, %i1, %i2)
// CHECK-NEXT: %2 = affine_apply [[MAP_D2]](%i0, %i1, %i2)
// CHECK-NEXT: %3 = affine_apply [[MAP_D0]](%i0, %i1, %i2)
// CHECK-NEXT: %4 = affine_apply [[MAP_IDENTITY]](%1, %2, %3)
// CHECK-NEXT: store %cst, %0[%4#0, %4#1, %4#2] : memref<10x20x30xf32>
// CHECK-NEXT: %5 = affine_apply [[MAP2]](%i0, %i1, %i2)
// CHECK-NEXT: %5 = affine_apply [[MAP_PERMUTE]](%i0, %i1, %i2)
// CHECK-NEXT: %6 = load %0[%5#0, %5#1, %5#2] : memref<10x20x30xf32>
// CHECK-NEXT: }
// CHECK-NEXT: }
@@ -627,3 +631,141 @@ func @fuse_reshape_64_16_4(%in : memref<64xf32>) {
// CHECK-NEXT: }
// CHECK-NEXT: return
}
// -----
// CHECK: #map0 = (d0) -> (d0 floordiv 4)
// CHECK: #map1 = (d0) -> (d0 mod 4)
// Reshape a 16x4xf32 to 64xf32.
// CHECK-LABEL: func @fuse_reshape_16_4_64
func @fuse_reshape_16_4_64() {
%in = alloc() : memref<16x4xf32>
%out = alloc() : memref<64xf32>
for %i0 = 0 to 16 {
for %i1 = 0 to 4 {
%v = load %in[%i0, %i1] : memref<16x4xf32>
%idx = affine_apply (d0, d1) -> (4*d0 + d1) (%i0, %i1)
store %v, %out[%idx] : memref<64xf32>
}
}
for %i2 = 0 to 64 {
%w = load %out[%i2] : memref<64xf32>
"foo"(%w) : (f32) -> ()
}
// CHECK: for %i0 = 0 to 64 {
// CHECK-NEXT: %2 = affine_apply #map0(%i0)
// CHECK-NEXT: %3 = affine_apply #map1(%i0)
// CHECK-NEXT: %4 = load %0[%2, %3] : memref<16x4xf32>
// CHECK-NEXT: %5 = affine_apply #map2(%2, %3)
// CHECK-NEXT: store %4, %1[%5] : memref<64xf32>
// CHECK-NEXT: %6 = load %1[%i0] : memref<64xf32>
// CHECK-NEXT: "foo"(%6) : (f32) -> ()
// CHECK-NEXT: }
// CHECK-NEXT: return
return
}
// -----
// All three loop nests below (6-d one, 2-d one, 2-d one) are fused into a single
// 2-d loop nest.
// CHECK-LABEL: func @R6_to_R2_reshape
func @R6_to_R2_reshape_square() -> memref<64x9xi32> {
%in = alloc() : memref<2x2x3x3x16x1xi32>
%out = alloc() : memref<64x9xi32>
// Initialize input with a different value for each 8x128 chunk.
for %i0 = 0 to 2 {
for %i1 = 0 to 2 {
for %i2 = 0 to 3 {
for %i3 = 0 to 3 {
for %i4 = 0 to 16 {
for %i5 = 0 to 1 {
%val = "foo"(%i0, %i1, %i2, %i3, %i4, %i5) : (index, index, index, index, index, index) -> i32
store %val, %in[%i0, %i1, %i2, %i3, %i4, %i5] : memref<2x2x3x3x16x1xi32>
}
}
}
}
}
}
for %ii = 0 to 64 {
for %jj = 0 to 9 {
// Convert output coordinates to linear index.
%a0 = affine_apply (d0, d1) -> (d0 * 9 + d1) (%ii, %jj)
%a1 = affine_apply (d0) -> (
d0 floordiv (2 * 3 * 3 * 16 * 1),
(d0 mod 288) floordiv (3 * 3 * 16 * 1),
((d0 mod 288) mod 144) floordiv 48,
(((d0 mod 288) mod 144) mod 48) floordiv 16,
((((d0 mod 288) mod 144) mod 48) mod 16),
(((d0 mod 144) mod 144) mod 48) mod 16
) (%a0)
%v = load %in[%a1#0, %a1#1, %a1#3, %a1#4, %a1#2, %a1#5]
: memref<2x2x3x3x16x1xi32>
store %v, %out[%ii, %jj] : memref<64x9xi32>
}
}
for %i = 0 to 64 {
for %j = 0 to 9 {
%a = load %out[%i, %j] : memref<64x9xi32>
%b = muli %a, %a : i32
store %b, %out[%i, %j] : memref<64x9xi32>
}
}
return %out : memref<64x9xi32>
}
// Everything above is fused to a single 2-d loop nest, and the 6-d tensor %in
// is eliminated if -memref-dataflow-opt is also supplied.
//
// CHECK: for %i0 = 0 to 64 {
// CHECK-NEXT: for %i1 = 0 to 9 {
// CHECK-NEXT: %2 = affine_apply #map0(%i0, %i1)
// CHECK-NEXT: %3 = affine_apply #map1(%i0, %i1)
// CHECK-NEXT: %4 = affine_apply #map2(%i0, %i1)
// CHECK-NEXT: %5 = affine_apply #map3(%i0, %i1)
// CHECK-NEXT: %6 = affine_apply #map4(%i0, %i1)
// CHECK-NEXT: %7 = "foo"(%2, %3, %4, %5, %6, %c0) : (index, index, index, index, index, index) -> i32
// CHECK-NEXT: store %7, %0[%2, %3, %4, %5, %6, %c0] : memref<2x2x3x3x16x1xi32>
// CHECK-NEXT: %8 = affine_apply #map5(%i0, %i1)
// CHECK-NEXT: %9 = affine_apply #map6(%i0, %i1)
// CHECK-NEXT: %10 = affine_apply #map7(%8, %9)
// CHECK-NEXT: %11 = affine_apply #map8(%10)
// CHECK-NEXT: %12 = load %0[%11#0, %11#1, %11#3, %11#4, %11#2, %11#5] : memref<2x2x3x3x16x1xi32>
// CHECK-NEXT: store %12, %1[%8, %9] : memref<64x9xi32>
// CHECK-NEXT: %13 = load %1[%i0, %i1] : memref<64x9xi32>
// CHECK-NEXT: %14 = muli %13, %13 : i32
// CHECK-NEXT: store %14, %1[%i0, %i1] : memref<64x9xi32>
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: return %1 : memref<64x9xi32>
// -----
// CHECK-LABEL: func @fuse_symbolic_bounds
func @fuse_symbolic_bounds(%M : index, %N : index) {
%m = alloc() : memref<800x800xf32>
%c0 = constant 0.0 : f32
%s = constant 5 : index
for %i0 = 0 to %M {
for %i1 = 0 to (d0) -> (d0 + 5) (%N) {
store %c0, %m[%i0, %i1] : memref<800 x 800 x f32>
}
}
for %i2 = 0 to %M {
for %i3 = 0 to %N {
%idx = affine_apply (d0, d1)[s0] -> (d0, d1 + s0) (%i2, %i3)[%s]
%v = load %m[%idx#0, %idx#1] : memref<800 x 800 x f32>
}
}
return
}

View File

@@ -167,7 +167,7 @@ func @multi_store_load_nested_fwd(%N : index) {
return
}
// No one-to-one dependence here between the store and load.
// There is no unique load location for the store to forward to.
// CHECK-LABEL: func @store_load_no_fwd
func @store_load_no_fwd() {
%cf7 = constant 7.0 : f32