[MLIR] Generalize detecting mods during slice computing

During slice computation of affine loop fusion, detect one id as the mod
of another id w.r.t. a constant in a more generic way. Restrictions on
the coefficients of the ids are removed. Also, information from the
previously calculated ids is used for simplification of affine
expressions, e.g.,

If `id1` = `id2`,
  `id_n - divisor * id_q - id_r + id1 - id2 = 0`, is simplified to:
  `id_n - divisor * id_q - id_r = 0`.

If `c` is a non-zero integer,
  `c*id_n - c*divisor * id_q - c*id_r = 0`, is simplified to:
  `id_n - divisor * id_q - id_r = 0`.

Reviewed By: bondhugula, ayzhuang

Differential Revision: https://reviews.llvm.org/D104614
This commit is contained in:
Vinayaka Bandishti
2021-06-23 12:25:09 +05:30
committed by Uday Bondhugula
parent 0e55112242
commit a873b6d466
2 changed files with 167 additions and 66 deletions

View File

@@ -1430,88 +1430,123 @@ unsigned FlatAffineConstraints::gaussianEliminateIds(unsigned posStart,
return posLimit - posStart;
}
// Detect the identifier at 'pos' (say id_r) as modulo of another identifier
// (say id_n) w.r.t a constant. When this happens, another identifier (say id_q)
// could be detected as the floordiv of n. For eg:
// id_n - 4*id_q - id_r = 0, 0 <= id_r <= 3 <=>
// id_r = id_n mod 4, id_q = id_n floordiv 4.
// lbConst and ubConst are the constant lower and upper bounds for 'pos' -
// pre-detected at the caller.
// Determine whether the identifier at 'pos' (say id_r) can be expressed as
// modulo of another known identifier (say id_n) w.r.t a constant. For example,
// if the following constraints hold true:
// ```
// 0 <= id_r <= divisor - 1
// id_n - (divisor * q_expr) = id_r
// ```
// where `id_n` is a known identifier (called dividend), and `q_expr` is an
// `AffineExpr` (called the quotient expression), `id_r` can be written as:
//
// `id_r = id_n mod divisor`.
//
// Additionally, in a special case of the above constraints where `q_expr` is an
// identifier itself that is not yet known (say `id_q`), it can be written as a
// floordiv in the following way:
//
// `id_q = id_n floordiv divisor`.
//
// Returns true if the above mod or floordiv are detected, updating 'memo' with
// these new expressions. Returns false otherwise.
static bool detectAsMod(const FlatAffineConstraints &cst, unsigned pos,
int64_t lbConst, int64_t ubConst,
SmallVectorImpl<AffineExpr> *memo) {
SmallVectorImpl<AffineExpr> &memo,
MLIRContext *context) {
assert(pos < cst.getNumIds() && "invalid position");
// Check if 0 <= id_r <= divisor - 1 and if id_r is equal to
// id_n - divisor * id_q. If these are true, then id_n becomes the dividend
// and id_q the quotient when dividing id_n by the divisor.
// Check if a divisor satisfying the condition `0 <= id_r <= divisor - 1` can
// be determined.
if (lbConst != 0 || ubConst < 1)
return false;
int64_t divisor = ubConst + 1;
// Now check for: id_r = id_n - divisor * id_q. As an example, we
// are looking r = d - 4q, i.e., either r - d + 4q = 0 or -r + d - 4q = 0.
unsigned seenQuotient = 0, seenDividend = 0;
int quotientPos = -1, dividendPos = -1;
for (unsigned r = 0, e = cst.getNumEqualities(); r < e; r++) {
// id_n should have coeff 1 or -1.
if (std::abs(cst.atEq(r, pos)) != 1)
// Check for the aforementioned conditions in each equality.
for (unsigned curEquality = 0, numEqualities = cst.getNumEqualities();
curEquality < numEqualities; curEquality++) {
int64_t coefficientAtPos = cst.atEq(curEquality, pos);
// If current equality does not involve `id_r`, continue to the next
// equality.
if (coefficientAtPos == 0)
continue;
// constant term should be 0.
if (cst.atEq(r, cst.getNumCols() - 1) != 0)
// Constant term should be 0 in this equality.
if (cst.atEq(curEquality, cst.getNumCols() - 1) != 0)
continue;
unsigned c, f;
int quotientSign = 1, dividendSign = 1;
for (c = 0, f = cst.getNumDimAndSymbolIds(); c < f; c++) {
if (c == pos)
// Traverse through the equality and construct the dividend expression
// `dividendExpr`, to contain all the known identifiers whose coefficients
// are not divisible by `(coefficientAtPos * divisor)`. The hope is that
// `dividendExpr` gets simplified into the single identifier `id_n`
// discussed above.
auto dividendExpr = getAffineConstantExpr(0, context);
// Track the terms that go into quotient expression, later used to detect
// additional floordiv.
unsigned quotientCount = 0;
int quotientPosition = -1;
int quotientSign = 1;
// Consider each term in the current equality.
unsigned curId, e;
for (curId = 0, e = cst.getNumDimAndSymbolIds(); curId < e; ++curId) {
// Ignore id_r.
if (curId == pos)
continue;
int64_t coefficientOfCurId = cst.atEq(curEquality, curId);
// Ignore ids that do not contribute to the current equality.
if (coefficientOfCurId == 0)
continue;
// Check if the current id goes into the quotient expression.
if (coefficientOfCurId % (divisor * coefficientAtPos) == 0) {
quotientCount++;
quotientPosition = curId;
quotientSign = (coefficientOfCurId * coefficientAtPos) > 0 ? 1 : -1;
continue;
// The coefficient of the quotient should be +/-divisor.
// TODO: could be extended to detect an affine function for the quotient
// (i.e., the coeff could be a non-zero multiple of divisor).
int64_t v = cst.atEq(r, c) * cst.atEq(r, pos);
if (v == divisor || v == -divisor) {
seenQuotient++;
quotientPos = c;
quotientSign = v > 0 ? 1 : -1;
}
// The coefficient of the dividend should be +/-1.
// TODO: could be extended to detect an affine function of the other
// identifiers as the dividend.
else if (v == -1 || v == 1) {
seenDividend++;
dividendPos = c;
dividendSign = v < 0 ? 1 : -1;
} else if (cst.atEq(r, c) != 0) {
// Cannot be inferred as a mod since the constraint has a coefficient
// for an identifier that's neither a unit nor the divisor (see TODOs
// above).
// Identifiers that are part of dividendExpr should be known.
if (!memo[curId])
break;
}
// Append the current identifier to the dividend expression.
dividendExpr = dividendExpr + memo[curId] * coefficientOfCurId;
}
if (c < f)
// Cannot be inferred as a mod since the constraint has a coefficient for
// an identifier that's neither a unit nor the divisor (see TODOs above).
// Can't construct expression as it depends on a yet uncomputed id.
if (curId < e)
continue;
// We are looking for exactly one identifier as the dividend.
if (seenDividend == 1 && seenQuotient >= 1) {
if (!(*memo)[dividendPos])
return false;
// Successfully detected a mod.
(*memo)[pos] = (*memo)[dividendPos] % divisor * dividendSign;
auto ub = cst.getConstantUpperBound(dividendPos);
if (ub.hasValue() && ub.getValue() < divisor)
// The mod can be optimized away.
(*memo)[pos] = (*memo)[dividendPos] * dividendSign;
else
(*memo)[pos] = (*memo)[dividendPos] % divisor * dividendSign;
// Express `id_r` in terms of the other ids collected so far.
if (coefficientAtPos > 0)
dividendExpr = (-dividendExpr).floorDiv(coefficientAtPos);
else
dividendExpr = dividendExpr.floorDiv(-coefficientAtPos);
// Simplify the expression.
dividendExpr = simplifyAffineExpr(dividendExpr, cst.getNumDimIds(),
cst.getNumSymbolIds());
// Only if the final dividend expression is just a single id (which we call
// `id_n`), we can proceed.
// TODO: Handle AffineSymbolExpr as well. There is no reason to restrict it
// to dims themselves.
auto dimExpr = dividendExpr.dyn_cast<AffineDimExpr>();
if (!dimExpr)
continue;
// Express `id_r` as `id_n % divisor` and store the expression in `memo`.
if (quotientCount >= 1) {
auto ub = cst.getConstantUpperBound(dimExpr.getPosition());
// If `id_n` has an upperbound that is less than the divisor, mod can be
// eliminated altogether.
if (ub.hasValue() && ub.getValue() < divisor)
memo[pos] = dimExpr;
else
memo[pos] = dimExpr % divisor;
// If a unique quotient `id_q` was seen, it can be expressed as
// `id_n floordiv divisor`.
if (quotientCount == 1 && !memo[quotientPosition])
memo[quotientPosition] = dimExpr.floorDiv(divisor) * quotientSign;
if (seenQuotient == 1 && !(*memo)[quotientPos])
// Successfully detected a floordiv as well.
(*memo)[quotientPos] =
(*memo)[dividendPos].floorDiv(divisor) * quotientSign;
return true;
}
}
@@ -1885,7 +1920,7 @@ void FlatAffineConstraints::getSliceBounds(unsigned offset, unsigned num,
// Detect an identifier as modulo of another identifier w.r.t a
// constant.
if (detectAsMod(*this, pos, lbConst.getValue(), ubConst.getValue(),
&memo)) {
memo, context)) {
changed = true;
continue;
}

View File

@@ -3330,3 +3330,69 @@ func @fuse_large_number_of_loops(%arg0: memref<20x10xf32, 1>, %arg1: memref<20x1
// CHECK: affine.for
// CHECK: affine.for
// CHECK-NOT: affine.for
// -----
// Expects fusion of producer into consumer at depth 4 and subsequent removal of
// source loop.
// CHECK-LABEL: func @unflatten4d
// Test input: a producer nest writes a flat 5040-element buffer through a
// linearized 4-d index, and a consumer nest reads it back through the same
// linearized index while writing a 7x8x9x10 memref with the un-linearized
// indices.
func @unflatten4d(%arg1: memref<7x8x9x10xf32>) {
// Intermediate flat buffer; 5040 = 7 * 8 * 9 * 10.
%m = memref.alloc() : memref<5040xf32>
%cf7 = constant 7.0 : f32
// Producer nest: stores the constant at a linearized index whose
// coefficients are 720 = 8 * 9 * 10, 90 = 9 * 10, 10 and 1.
affine.for %i0 = 0 to 7 {
affine.for %i1 = 0 to 8 {
affine.for %i2 = 0 to 9 {
affine.for %i3 = 0 to 10 {
affine.store %cf7, %m[720 * %i0 + 90 * %i1 + 10 * %i2 + %i3] : memref<5040xf32>
}
}
}
}
// Consumer nest: same loop bounds and same linearized access on %m; the loaded
// value is stored to the 4-d memref at [%i0, %i1, %i2, %i3].
affine.for %i0 = 0 to 7 {
affine.for %i1 = 0 to 8 {
affine.for %i2 = 0 to 9 {
affine.for %i3 = 0 to 10 {
%v0 = affine.load %m[720 * %i0 + 90 * %i1 + 10 * %i2 + %i3] : memref<5040xf32>
affine.store %v0, %arg1[%i0, %i1, %i2, %i3] : memref<7x8x9x10xf32>
}
}
}
}
return
}
// CHECK: affine.for
// CHECK-NEXT: affine.for
// CHECK-NEXT: affine.for
// CHECK-NEXT: affine.for
// CHECK-NOT: affine.for
// CHECK: return
// -----
// Expects fusion of producer into consumer at depth 2 and subsequent removal of
// source loop.
// CHECK-LABEL: func @unflatten2d_with_transpose
// Test input: a producer nest writes a flat 56-element buffer through a
// linearized 2-d index, and a consumer nest with the two loop ranges swapped
// reads it back and writes an 8x7 memref.
func @unflatten2d_with_transpose(%arg1: memref<8x7xf32>) {
// Intermediate flat buffer; 56 = 7 * 8.
%m = memref.alloc() : memref<56xf32>
%cf7 = constant 7.0 : f32
// Producer nest: outer loop has 7 iterations and coefficient 8, inner loop
// has 8 iterations and coefficient 1.
affine.for %i0 = 0 to 7 {
affine.for %i1 = 0 to 8 {
affine.store %cf7, %m[8 * %i0 + %i1] : memref<56xf32>
}
}
// Consumer nest: loop ranges are swapped relative to the producer, so here the
// outer loop (8 iterations) carries coefficient 1 and the inner loop
// (7 iterations) carries coefficient 8.
affine.for %i0 = 0 to 8 {
affine.for %i1 = 0 to 7 {
%v0 = affine.load %m[%i0 + 8 * %i1] : memref<56xf32>
affine.store %v0, %arg1[%i0, %i1] : memref<8x7xf32>
}
}
return
}
// CHECK: affine.for
// CHECK-NEXT: affine.for
// CHECK-NOT: affine.for
// CHECK: return