mirror of
https://github.com/intel/llvm.git
synced 2026-01-26 12:26:52 +08:00
[MLIR] Generalize detecting mods during slice computing
During slice computation of affine loop fusion, detect one id as the mod of another id w.r.t a constant in a more generic way. Restrictions on co-efficients of the ids is removed. Also, information from the previously calculated ids is used for simplification of affine expressions, e.g., If `id1` = `id2`, `id_n - divisor * id_q - id_r + id1 - id2 = 0`, is simplified to: `id_n - divisor * id_q - id_r = 0`. If `c` is a non-zero integer, `c*id_n - c*divisor * id_q - c*id_r = 0`, is simplified to: `id_n - divisor * id_q - id_r = 0`. Reviewed By: bondhugula, ayzhuang Differential Revision: https://reviews.llvm.org/D104614
This commit is contained in:
committed by
Uday Bondhugula
parent
0e55112242
commit
a873b6d466
@@ -1430,88 +1430,123 @@ unsigned FlatAffineConstraints::gaussianEliminateIds(unsigned posStart,
|
||||
return posLimit - posStart;
|
||||
}
|
||||
|
||||
// Detect the identifier at 'pos' (say id_r) as modulo of another identifier
|
||||
// (say id_n) w.r.t a constant. When this happens, another identifier (say id_q)
|
||||
// could be detected as the floordiv of n. For eg:
|
||||
// id_n - 4*id_q - id_r = 0, 0 <= id_r <= 3 <=>
|
||||
// id_r = id_n mod 4, id_q = id_n floordiv 4.
|
||||
// lbConst and ubConst are the constant lower and upper bounds for 'pos' -
|
||||
// pre-detected at the caller.
|
||||
// Determine whether the identifier at 'pos' (say id_r) can be expressed as
|
||||
// modulo of another known identifier (say id_n) w.r.t a constant. For example,
|
||||
// if the following constraints hold true:
|
||||
// ```
|
||||
// 0 <= id_r <= divisor - 1
|
||||
// id_n - (divisor * q_expr) = id_r
|
||||
// ```
|
||||
// where `id_n` is a known identifier (called dividend), and `q_expr` is an
|
||||
// `AffineExpr` (called the quotient expression), `id_r` can be written as:
|
||||
//
|
||||
// `id_r = id_n mod divisor`.
|
||||
//
|
||||
// Additionally, in a special case of the above constaints where `q_expr` is an
|
||||
// identifier itself that is not yet known (say `id_q`), it can be written as a
|
||||
// floordiv in the following way:
|
||||
//
|
||||
// `id_q = id_n floordiv divisor`.
|
||||
//
|
||||
// Returns true if the above mod or floordiv are detected, updating 'memo' with
|
||||
// these new expressions. Returns false otherwise.
|
||||
static bool detectAsMod(const FlatAffineConstraints &cst, unsigned pos,
|
||||
int64_t lbConst, int64_t ubConst,
|
||||
SmallVectorImpl<AffineExpr> *memo) {
|
||||
SmallVectorImpl<AffineExpr> &memo,
|
||||
MLIRContext *context) {
|
||||
assert(pos < cst.getNumIds() && "invalid position");
|
||||
|
||||
// Check if 0 <= id_r <= divisor - 1 and if id_r is equal to
|
||||
// id_n - divisor * id_q. If these are true, then id_n becomes the dividend
|
||||
// and id_q the quotient when dividing id_n by the divisor.
|
||||
|
||||
// Check if a divisor satisfying the condition `0 <= id_r <= divisor - 1` can
|
||||
// be determined.
|
||||
if (lbConst != 0 || ubConst < 1)
|
||||
return false;
|
||||
|
||||
int64_t divisor = ubConst + 1;
|
||||
|
||||
// Now check for: id_r = id_n - divisor * id_q. As an example, we
|
||||
// are looking r = d - 4q, i.e., either r - d + 4q = 0 or -r + d - 4q = 0.
|
||||
unsigned seenQuotient = 0, seenDividend = 0;
|
||||
int quotientPos = -1, dividendPos = -1;
|
||||
for (unsigned r = 0, e = cst.getNumEqualities(); r < e; r++) {
|
||||
// id_n should have coeff 1 or -1.
|
||||
if (std::abs(cst.atEq(r, pos)) != 1)
|
||||
// Check for the aforementioned conditions in each equality.
|
||||
for (unsigned curEquality = 0, numEqualities = cst.getNumEqualities();
|
||||
curEquality < numEqualities; curEquality++) {
|
||||
int64_t coefficientAtPos = cst.atEq(curEquality, pos);
|
||||
// If current equality does not involve `id_r`, continue to the next
|
||||
// equality.
|
||||
if (coefficientAtPos == 0)
|
||||
continue;
|
||||
// constant term should be 0.
|
||||
if (cst.atEq(r, cst.getNumCols() - 1) != 0)
|
||||
|
||||
// Constant term should be 0 in this equality.
|
||||
if (cst.atEq(curEquality, cst.getNumCols() - 1) != 0)
|
||||
continue;
|
||||
unsigned c, f;
|
||||
int quotientSign = 1, dividendSign = 1;
|
||||
for (c = 0, f = cst.getNumDimAndSymbolIds(); c < f; c++) {
|
||||
if (c == pos)
|
||||
|
||||
// Traverse through the equality and construct the dividend expression
|
||||
// `dividendExpr`, to contain all the identifiers which are known and are
|
||||
// not divisible by `(coefficientAtPos * divisor)`. Hope here is that the
|
||||
// `dividendExpr` gets simplified into a single identifier `id_n` discussed
|
||||
// above.
|
||||
auto dividendExpr = getAffineConstantExpr(0, context);
|
||||
|
||||
// Track the terms that go into quotient expression, later used to detect
|
||||
// additional floordiv.
|
||||
unsigned quotientCount = 0;
|
||||
int quotientPosition = -1;
|
||||
int quotientSign = 1;
|
||||
|
||||
// Consider each term in the current equality.
|
||||
unsigned curId, e;
|
||||
for (curId = 0, e = cst.getNumDimAndSymbolIds(); curId < e; ++curId) {
|
||||
// Ignore id_r.
|
||||
if (curId == pos)
|
||||
continue;
|
||||
int64_t coefficientOfCurId = cst.atEq(curEquality, curId);
|
||||
// Ignore ids that do not contribute to the current equality.
|
||||
if (coefficientOfCurId == 0)
|
||||
continue;
|
||||
// Check if the current id goes into the quotient expression.
|
||||
if (coefficientOfCurId % (divisor * coefficientAtPos) == 0) {
|
||||
quotientCount++;
|
||||
quotientPosition = curId;
|
||||
quotientSign = (coefficientOfCurId * coefficientAtPos) > 0 ? 1 : -1;
|
||||
continue;
|
||||
// The coefficient of the quotient should be +/-divisor.
|
||||
// TODO: could be extended to detect an affine function for the quotient
|
||||
// (i.e., the coeff could be a non-zero multiple of divisor).
|
||||
int64_t v = cst.atEq(r, c) * cst.atEq(r, pos);
|
||||
if (v == divisor || v == -divisor) {
|
||||
seenQuotient++;
|
||||
quotientPos = c;
|
||||
quotientSign = v > 0 ? 1 : -1;
|
||||
}
|
||||
// The coefficient of the dividend should be +/-1.
|
||||
// TODO: could be extended to detect an affine function of the other
|
||||
// identifiers as the dividend.
|
||||
else if (v == -1 || v == 1) {
|
||||
seenDividend++;
|
||||
dividendPos = c;
|
||||
dividendSign = v < 0 ? 1 : -1;
|
||||
} else if (cst.atEq(r, c) != 0) {
|
||||
// Cannot be inferred as a mod since the constraint has a coefficient
|
||||
// for an identifier that's neither a unit nor the divisor (see TODOs
|
||||
// above).
|
||||
// Identifiers that are part of dividendExpr should be known.
|
||||
if (!memo[curId])
|
||||
break;
|
||||
}
|
||||
// Append the current identifier to the dividend expression.
|
||||
dividendExpr = dividendExpr + memo[curId] * coefficientOfCurId;
|
||||
}
|
||||
if (c < f)
|
||||
// Cannot be inferred as a mod since the constraint has a coefficient for
|
||||
// an identifier that's neither a unit nor the divisor (see TODOs above).
|
||||
|
||||
// Can't construct expression as it depends on a yet uncomputed id.
|
||||
if (curId < e)
|
||||
continue;
|
||||
|
||||
// We are looking for exactly one identifier as the dividend.
|
||||
if (seenDividend == 1 && seenQuotient >= 1) {
|
||||
if (!(*memo)[dividendPos])
|
||||
return false;
|
||||
// Successfully detected a mod.
|
||||
(*memo)[pos] = (*memo)[dividendPos] % divisor * dividendSign;
|
||||
auto ub = cst.getConstantUpperBound(dividendPos);
|
||||
if (ub.hasValue() && ub.getValue() < divisor)
|
||||
// The mod can be optimized away.
|
||||
(*memo)[pos] = (*memo)[dividendPos] * dividendSign;
|
||||
else
|
||||
(*memo)[pos] = (*memo)[dividendPos] % divisor * dividendSign;
|
||||
// Express `id_r` in terms of the other ids collected so far.
|
||||
if (coefficientAtPos > 0)
|
||||
dividendExpr = (-dividendExpr).floorDiv(coefficientAtPos);
|
||||
else
|
||||
dividendExpr = dividendExpr.floorDiv(-coefficientAtPos);
|
||||
|
||||
// Simplify the expression.
|
||||
dividendExpr = simplifyAffineExpr(dividendExpr, cst.getNumDimIds(),
|
||||
cst.getNumSymbolIds());
|
||||
// Only if the final dividend expression is just a single id (which we call
|
||||
// `id_n`), we can proceed.
|
||||
// TODO: Handle AffineSymbolExpr as well. There is no reason to restrict it
|
||||
// to dims themselves.
|
||||
auto dimExpr = dividendExpr.dyn_cast<AffineDimExpr>();
|
||||
if (!dimExpr)
|
||||
continue;
|
||||
|
||||
// Express `id_r` as `id_n % divisor` and store the expression in `memo`.
|
||||
if (quotientCount >= 1) {
|
||||
auto ub = cst.getConstantUpperBound(dimExpr.getPosition());
|
||||
// If `id_n` has an upperbound that is less than the divisor, mod can be
|
||||
// eliminated altogether.
|
||||
if (ub.hasValue() && ub.getValue() < divisor)
|
||||
memo[pos] = dimExpr;
|
||||
else
|
||||
memo[pos] = dimExpr % divisor;
|
||||
// If a unique quotient `id_q` was seen, it can be expressed as
|
||||
// `id_n floordiv divisor`.
|
||||
if (quotientCount == 1 && !memo[quotientPosition])
|
||||
memo[quotientPosition] = dimExpr.floorDiv(divisor) * quotientSign;
|
||||
|
||||
if (seenQuotient == 1 && !(*memo)[quotientPos])
|
||||
// Successfully detected a floordiv as well.
|
||||
(*memo)[quotientPos] =
|
||||
(*memo)[dividendPos].floorDiv(divisor) * quotientSign;
|
||||
return true;
|
||||
}
|
||||
}
|
||||
@@ -1885,7 +1920,7 @@ void FlatAffineConstraints::getSliceBounds(unsigned offset, unsigned num,
|
||||
// Detect an identifier as modulo of another identifier w.r.t a
|
||||
// constant.
|
||||
if (detectAsMod(*this, pos, lbConst.getValue(), ubConst.getValue(),
|
||||
&memo)) {
|
||||
memo, context)) {
|
||||
changed = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -3330,3 +3330,69 @@ func @fuse_large_number_of_loops(%arg0: memref<20x10xf32, 1>, %arg1: memref<20x1
|
||||
// CHECK: affine.for
|
||||
// CHECK: affine.for
|
||||
// CHECK-NOT: affine.for
|
||||
|
||||
// -----
|
||||
|
||||
// Expects fusion of producer into consumer at depth 4 and subsequent removal of
|
||||
// source loop.
|
||||
// CHECK-LABEL: func @unflatten4d
|
||||
func @unflatten4d(%arg1: memref<7x8x9x10xf32>) {
|
||||
%m = memref.alloc() : memref<5040xf32>
|
||||
%cf7 = constant 7.0 : f32
|
||||
|
||||
affine.for %i0 = 0 to 7 {
|
||||
affine.for %i1 = 0 to 8 {
|
||||
affine.for %i2 = 0 to 9 {
|
||||
affine.for %i3 = 0 to 10 {
|
||||
affine.store %cf7, %m[720 * %i0 + 90 * %i1 + 10 * %i2 + %i3] : memref<5040xf32>
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
affine.for %i0 = 0 to 7 {
|
||||
affine.for %i1 = 0 to 8 {
|
||||
affine.for %i2 = 0 to 9 {
|
||||
affine.for %i3 = 0 to 10 {
|
||||
%v0 = affine.load %m[720 * %i0 + 90 * %i1 + 10 * %i2 + %i3] : memref<5040xf32>
|
||||
affine.store %v0, %arg1[%i0, %i1, %i2, %i3] : memref<7x8x9x10xf32>
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK: affine.for
|
||||
// CHECK-NEXT: affine.for
|
||||
// CHECK-NEXT: affine.for
|
||||
// CHECK-NEXT: affine.for
|
||||
// CHECK-NOT: affine.for
|
||||
// CHECK: return
|
||||
|
||||
// -----
|
||||
|
||||
// Expects fusion of producer into consumer at depth 2 and subsequent removal of
|
||||
// source loop.
|
||||
// CHECK-LABEL: func @unflatten2d_with_transpose
|
||||
func @unflatten2d_with_transpose(%arg1: memref<8x7xf32>) {
|
||||
%m = memref.alloc() : memref<56xf32>
|
||||
%cf7 = constant 7.0 : f32
|
||||
|
||||
affine.for %i0 = 0 to 7 {
|
||||
affine.for %i1 = 0 to 8 {
|
||||
affine.store %cf7, %m[8 * %i0 + %i1] : memref<56xf32>
|
||||
}
|
||||
}
|
||||
affine.for %i0 = 0 to 8 {
|
||||
affine.for %i1 = 0 to 7 {
|
||||
%v0 = affine.load %m[%i0 + 8 * %i1] : memref<56xf32>
|
||||
affine.store %v0, %arg1[%i0, %i1] : memref<8x7xf32>
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK: affine.for
|
||||
// CHECK-NEXT: affine.for
|
||||
// CHECK-NOT: affine.for
|
||||
// CHECK: return
|
||||
|
||||
Reference in New Issue
Block a user