diff --git a/mlir/include/mlir/Analysis/AffineAnalysis.h b/mlir/include/mlir/Analysis/AffineAnalysis.h index 95cb61e16b79..ebd27e15b419 100644 --- a/mlir/include/mlir/Analysis/AffineAnalysis.h +++ b/mlir/include/mlir/Analysis/AffineAnalysis.h @@ -49,6 +49,14 @@ void getReachableAffineApplyOps( llvm::ArrayRef operands, llvm::SmallVectorImpl &affineApplyOps); +/// Flattens 'expr' into 'flattenedExpr'. Returns true on success or false +/// if 'expr' was unable to be flattened (i.e. because it was not pure affine, +/// or because it contained mod's and div's that could not be eliminated +/// without introducing local variables). +bool getFlattenedAffineExpr(AffineExpr expr, unsigned numDims, + unsigned numSymbols, + llvm::SmallVectorImpl *flattenedExpr); + } // end namespace mlir #endif // MLIR_ANALYSIS_AFFINE_ANALYSIS_H diff --git a/mlir/include/mlir/Analysis/AffineStructures.h b/mlir/include/mlir/Analysis/AffineStructures.h index 87a8e3e919bf..96add34ad0a6 100644 --- a/mlir/include/mlir/Analysis/AffineStructures.h +++ b/mlir/include/mlir/Analysis/AffineStructures.h @@ -248,6 +248,9 @@ public: explicit FlatAffineConstraints(const AffineValueMap &avm); explicit FlatAffineConstraints(ArrayRef avmRef); + /// Creates an affine constraint system from an IntegerSet. + explicit FlatAffineConstraints(IntegerSet set); + /// Create an affine constraint system from an IntegerValueSet. // TODO(bondhugula) explicit FlatAffineConstraints(const IntegerValueSet &set); @@ -259,6 +262,24 @@ public: ~FlatAffineConstraints() {} + // Checks for emptiness by performing variable elimination on all identifiers, + // running the GCD test on each equality constraint, and checking for invalid + // constraints. + // Returns true if the GCD test fails for any equality, or if any invalid + // constraints are discovered on any row. Returns false otherwise. + // TODO(andydavis) Change this method to operate on cloned constraints. + bool isEmpty(); + + // Eliminates a single identifier at 'position' from equality and inequality + // constraints. Returns 'true' if the identifier was eliminated. + // Returns 'false' otherwise. + bool eliminateIdentifier(unsigned position); + + // Eliminates identifiers from equality and inequality constraints + // in column range [posStart, posLimit). + // Returns the number of variables eliminated. + unsigned eliminateIdentifiers(unsigned posStart, unsigned posLimit); + inline int64_t atEq(unsigned i, unsigned j) const { return equalities[i * (numIds + 1) + j]; } @@ -267,6 +288,14 @@ public: return equalities[i * (numIds + 1) + j]; } + inline int64_t atEqIdx(unsigned linearIndex) const { + return equalities[linearIndex]; + } + + inline int64_t &atEqIdx(unsigned linearIndex) { + return equalities[linearIndex]; + } + inline int64_t atIneq(unsigned i, unsigned j) const { return inequalities[i * (numIds + 1) + j]; } @@ -275,6 +304,14 @@ public: return inequalities[i * (numIds + 1) + j]; } + inline int64_t atIneqIdx(unsigned linearIndex) const { + return inequalities[linearIndex]; + } + + inline int64_t &atIneqIdx(unsigned linearIndex) { + return inequalities[linearIndex]; + } + inline unsigned getNumCols() const { return numIds + 1; } inline unsigned getNumEqualities() const { @@ -323,6 +360,11 @@ public: void dump() const; private: + // Removes coefficients in column range [colStart, colLimit),and copies any + // remaining valid data into place, updates member variables, and resizes + // arrays as needed. + void removeColumnRange(unsigned colStart, unsigned colLimit); + /// Coefficients of affine equalities (in == 0 form). SmallVector equalities; diff --git a/mlir/include/mlir/IR/IntegerSet.h b/mlir/include/mlir/IR/IntegerSet.h index fb00e88676de..7ab9db476a44 100644 --- a/mlir/include/mlir/IR/IntegerSet.h +++ b/mlir/include/mlir/IR/IntegerSet.h @@ -59,6 +59,14 @@ public: ArrayRef constraints, ArrayRef eqFlags, MLIRContext *context); + // Returns a canonical empty IntegerSet (i.e. a set with no integer points). + static IntegerSet getEmptySet(unsigned numDims, unsigned numSymbols, + MLIRContext *context) { + auto one = getAffineConstantExpr(1, context); + /* 1 == 0 */ + return get(numDims, numSymbols, one, true, context); + } + explicit operator bool() { return set; } bool operator==(IntegerSet other) const { return set == other.set; } @@ -66,6 +74,8 @@ public: unsigned getNumSymbols() const; unsigned getNumOperands() const; unsigned getNumConstraints() const; + unsigned getNumEqualities() const; + unsigned getNumInequalities() const; ArrayRef getConstraints() const; @@ -79,6 +89,8 @@ public: /// inequality. bool isEq(unsigned idx) const; + MLIRContext *getContext() const; + void print(raw_ostream &os) const; void dump() const; diff --git a/mlir/include/mlir/IR/Statements.h b/mlir/include/mlir/IR/Statements.h index 5fa0126d8761..77a4327164ba 100644 --- a/mlir/include/mlir/IR/Statements.h +++ b/mlir/include/mlir/IR/Statements.h @@ -465,6 +465,10 @@ public: const AffineCondition getCondition() const; IntegerSet getIntegerSet() const { return set; } + void setIntegerSet(IntegerSet newSet) { + assert(newSet.getNumOperands() == operands.size()); + set = newSet; + } //===--------------------------------------------------------------------===// // Operands diff --git a/mlir/include/mlir/Support/MathExtras.h b/mlir/include/mlir/Support/MathExtras.h index 6e9835a27d15..fce4616898b3 100644 --- a/mlir/include/mlir/Support/MathExtras.h +++ b/mlir/include/mlir/Support/MathExtras.h @@ -51,6 +51,14 @@ inline int64_t mod(int64_t lhs, int64_t rhs) { return lhs % rhs < 0 ? lhs % rhs + rhs : lhs % rhs; } +/// Returns the least common multiple of 'a' and 'b'. +inline int64_t lcm(int64_t a, int64_t b) { + uint64_t x = std::abs(a); + uint64_t y = std::abs(b); + int64_t lcm = (x * y) / llvm::GreatestCommonDivisor64(x, y); + assert((lcm >= a && lcm >= b) && "LCM overflow"); + return lcm; +} } // end namespace mlir #endif // MLIR_SUPPORT_MATHEXTRAS_H_ diff --git a/mlir/lib/Analysis/AffineAnalysis.cpp b/mlir/lib/Analysis/AffineAnalysis.cpp index a16997c08d38..9b5cc99ff8e3 100644 --- a/mlir/lib/Analysis/AffineAnalysis.cpp +++ b/mlir/lib/Analysis/AffineAnalysis.cpp @@ -298,6 +298,30 @@ AffineExpr mlir::simplifyAffineExpr(AffineExpr expr, unsigned numDims, return simplifiedExpr; } +// Flattens 'expr' into 'flattenedExpr'. Returns true on success or false +// if 'expr' was unable to be flattened (i.e. because it was not pur affine, +// or because it contained mod's and div's that could not be eliminated +// without introducing local variables). +bool mlir::getFlattenedAffineExpr( + AffineExpr expr, unsigned numDims, unsigned numSymbols, + llvm::SmallVectorImpl *flattenedExpr) { + // TODO(bondhugula): only pure affine for now. The simplification here can be + // extended to semi-affine maps in the future. + if (!expr.isPureAffine()) + return false; + + AffineExprFlattener flattener(numDims, numSymbols, expr.getContext()); + flattener.walkPostOrder(expr); + // TODO(andydavis) Support local exprs. + if (flattener.numLocals > 0) { + return false; + } + for (auto v : flattener.operandExprStack.back()) { + flattenedExpr->push_back(v); + } + return true; +} + /// Returns the sequence of AffineApplyOp OperationStmts operation in /// 'affineApplyOps', which are reachable via a search starting from 'operands', /// and ending at operands which are not defined by AffineApplyOps. diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp index 173e8b50ab5a..e45b5ccc0646 100644 --- a/mlir/lib/Analysis/AffineStructures.cpp +++ b/mlir/lib/Analysis/AffineStructures.cpp @@ -27,6 +27,7 @@ #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/IntegerSet.h" #include "mlir/IR/MLValue.h" +#include "mlir/Support/MathExtras.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseSet.h" #include "llvm/Support/raw_ostream.h" @@ -438,6 +439,286 @@ AffineMap AffineValueMap::getAffineMap() { return map.getAffineMap(); } AffineValueMap::~AffineValueMap() {} +FlatAffineConstraints::FlatAffineConstraints(IntegerSet set) + : numReservedEqualities(0), numReservedInequalities(0), numReservedIds(0), + numIds(set.getNumDims() + set.getNumSymbols()), numDims(set.getNumDims()), + numSymbols(set.getNumSymbols()) { + unsigned numConstraints = set.getNumConstraints(); + for (unsigned i = 0; i < numConstraints; ++i) { + AffineExpr expr = set.getConstraint(i); + SmallVector flattenedExpr; + getFlattenedAffineExpr(expr, set.getNumDims(), set.getNumSymbols(), + &flattenedExpr); + assert(flattenedExpr.size() == getNumCols()); + if (set.getEqFlags()[i]) { + addEquality(flattenedExpr); + } else { + addInequality(flattenedExpr); + } + } +} + +// Searches for a constraint with a non-zero coefficient at 'colIdx' in +// equality (isEq=true) or inequality (isEq=false) constraints. +// Returns true and sets row found in search in 'rowIdx'. +// Returns false otherwise. +static bool +findConstraintWithNonZeroAt(const FlatAffineConstraints &constraints, + unsigned colIdx, bool isEq, unsigned &rowIdx) { + auto at = [&](unsigned rowIdx) -> int64_t { + return isEq ? constraints.atEq(rowIdx, colIdx) + : constraints.atIneq(rowIdx, colIdx); + }; + unsigned e = + isEq ? constraints.getNumEqualities() : constraints.getNumInequalities(); + for (rowIdx = 0; rowIdx < e; ++rowIdx) { + if (at(rowIdx) != 0) { + return true; + } + } + return false; +} + +// Normalizes the coefficient values across all columns in 'rowIDx' by their +// GCD in equality or inequality contraints as specified by 'isEq'. +static void normalizeConstraintByGCD(FlatAffineConstraints *constraints, + unsigned rowIdx, bool isEq) { + auto at = [&](unsigned colIdx) -> int64_t { + return isEq ? constraints->atEq(rowIdx, colIdx) + : constraints->atIneq(rowIdx, colIdx); + }; + uint64_t gcd = std::abs(at(0)); + for (unsigned j = 1; j < constraints->getNumCols(); ++j) { + gcd = llvm::GreatestCommonDivisor64(gcd, std::abs(at(j))); + } + if (gcd > 0 && gcd != 1) { + for (unsigned j = 0; j < constraints->getNumCols(); ++j) { + int64_t v = at(j) / static_cast(gcd); + isEq ? constraints->atEq(rowIdx, j) = v + : constraints->atIneq(rowIdx, j) = v; + } + } +} + +// Runs the GCD test on all equality constraints. Returns 'true' if this test +// fails on any equality. Returns 'false' otherwise. +// This test can be used to disprove the existence of a solution. If it returns +// true, no integer solution to the equality constraints can exist. +// +// GCD test definition: +// +// The equality constraint: +// +// c_1*x_1 + c_2*x_2 + ... + c_n*x_n = c_0 +// +// has an integer solution iff: +// +// GCD of c_1, c_2, ..., c_n divides c_0. +// +static bool isEmptyByGCDTest(const FlatAffineConstraints &constraints) { + unsigned numCols = constraints.getNumCols(); + for (unsigned i = 0, e = constraints.getNumEqualities(); i < e; ++i) { + uint64_t gcd = std::abs(constraints.atEq(i, 0)); + for (unsigned j = 1; j < numCols - 1; ++j) { + gcd = + llvm::GreatestCommonDivisor64(gcd, std::abs(constraints.atEq(i, j))); + } + int64_t v = std::abs(constraints.atEq(i, numCols - 1)); + if (gcd > 0 && (v % gcd != 0)) { + return true; + } + } + return false; +} + +// Checks all rows of equality/inequality constraints for contradictions +// (i.e. 1 == 0), which may have surfaced after elimination. +// Returns 'true' if a valid constraint is detected. Returns 'false' otherwise. +static bool hasInvalidConstraint(const FlatAffineConstraints &constraints) { + auto check = [constraints](bool isEq) -> bool { + unsigned numCols = constraints.getNumCols(); + unsigned numRows = isEq ? constraints.getNumEqualities() + : constraints.getNumInequalities(); + for (unsigned i = 0, e = numRows; i < e; ++i) { + unsigned j; + for (j = 0; j < numCols - 1; ++j) { + int64_t v = isEq ? constraints.atEq(i, j) : constraints.atIneq(i, j); + // Skip rows with non-zero variable coefficients. + if (v != 0) + break; + } + if (j < numCols - 1) { + continue; + } + // Check validity of constant term at 'numCols - 1' w.r.t 'isEq'. + // Example invalid constraints include: '1 == 0' or '-1 >= 0' + int64_t v = isEq ? constraints.atEq(i, numCols - 1) + : constraints.atIneq(i, numCols - 1); + if ((isEq && v != 0) || (!isEq && v < 0)) { + return true; + } + } + return false; + }; + if (check(/*isEq=*/true)) + return true; + return check(/*isEq=*/false); +} + +// Eliminate identifier from constraint at 'rowIdx' based on coefficient at +// pivotRow, pivotCol. Columns in range [elimColStart, pivotCol) will not be +// updated as they have already been eliminated. +static void eliminateFromConstraint(FlatAffineConstraints *constraints, + unsigned rowIdx, unsigned pivotRow, + unsigned pivotCol, unsigned elimColStart, + bool isEq) { + // Skip if equality 'rowIdx' if same as 'pivotRow'. + if (isEq && rowIdx == pivotRow) + return; + auto at = [&](unsigned i, unsigned j) -> int64_t { + return isEq ? constraints->atEq(i, j) : constraints->atIneq(i, j); + }; + int64_t leadCoeff = at(rowIdx, pivotCol); + // Skip if leading coefficient at 'rowIdx' is already zero. + if (leadCoeff == 0) + return; + int64_t pivotCoeff = constraints->atEq(pivotRow, pivotCol); + int64_t sign = (leadCoeff * pivotCoeff > 0) ? -1 : 1; + int64_t lcm = mlir::lcm(pivotCoeff, leadCoeff); + int64_t pivotMultiplier = sign * (lcm / std::abs(pivotCoeff)); + int64_t rowMultiplier = lcm / std::abs(leadCoeff); + + unsigned numCols = constraints->getNumCols(); + for (unsigned j = 0; j < numCols; ++j) { + // Skip updating column 'j' if it was just eliminated. + if (j >= elimColStart && j < pivotCol) + continue; + int64_t v = pivotMultiplier * constraints->atEq(pivotRow, j) + + rowMultiplier * at(rowIdx, j); + isEq ? constraints->atEq(rowIdx, j) = v + : constraints->atIneq(rowIdx, j) = v; + } +} + +// Remove coefficients in column range [colStart, colLimit) in place. +// This removes in data in the specified column range, and copies any +// remaining valid data into place. +static void removeColumns(FlatAffineConstraints *constraints, unsigned colStart, + unsigned colLimit, bool isEq) { + unsigned numCols = constraints->getNumCols(); + unsigned newNumCols = numCols - (colLimit - colStart); + unsigned numRows = isEq ? constraints->getNumEqualities() + : constraints->getNumInequalities(); + for (unsigned i = 0, e = numRows; i < e; ++i) { + for (unsigned j = 0; j < numCols; ++j) { + if (j >= colStart && j < colLimit) + continue; + unsigned inputIndex = i * numCols + j; + unsigned outputOffset = j >= colLimit ? j - (colLimit - colStart) : j; + unsigned outputIndex = i * newNumCols + outputOffset; + assert(outputIndex <= inputIndex); + if (isEq) { + constraints->atEqIdx(outputIndex) = constraints->atEqIdx(inputIndex); + } else { + constraints->atIneqIdx(outputIndex) = + constraints->atIneqIdx(inputIndex); + } + } + } +} + +// Removes coefficients in column range [colStart, colLimit),and copies any +// remaining valid data into place, updates member variables, and resizes +// arrays as needed. +void FlatAffineConstraints::removeColumnRange(unsigned colStart, + unsigned colLimit) { + // TODO(andydavis) Make 'removeColumns' a lambda called from here. + // Remove eliminated columns from equalities. + removeColumns(this, colStart, colLimit, /*isEq=*/true); + // Remove eliminated columns from inequalities. + removeColumns(this, colStart, colLimit, /*isEq=*/false); + // Update members numDims, numSymbols and numIds. + unsigned numDimsEliminated = 0; + if (colStart < numDims) { + numDimsEliminated = std::min(numDims, colLimit) - colStart; + } + unsigned numEqualities = getNumEqualities(); + unsigned numInequalities = getNumInequalities(); + unsigned numColsEliminated = colLimit - colStart; + unsigned numSymbolsEliminated = + std::min(numSymbols, numColsEliminated - numDimsEliminated); + numDims -= numDimsEliminated; + numSymbols -= numSymbolsEliminated; + numIds = numIds - numColsEliminated; + equalities.resize(numEqualities * getNumCols()); + inequalities.resize(numInequalities * getNumCols()); +} + +// Performs variable elimination on all identifiers, runs the GCD test on +// all equality constraint rows, and checks the constraint validity. +// Returns 'true' if the GCD test fails on any row, or if any invalid +// constraint is detected. Returns 'false' otherwise. +bool FlatAffineConstraints::isEmpty() { + if (eliminateIdentifiers(0, numIds) == 0) + return false; + if (isEmptyByGCDTest(*this)) + return true; + if (hasInvalidConstraint(*this)) + return true; + return false; +} + +// Eliminates a single identifier at 'position' from equality and inequality +// constraints. Returns 'true' if the identifier was eliminated. +// Returns 'false' otherwise. +bool FlatAffineConstraints::eliminateIdentifier(unsigned position) { + return eliminateIdentifiers(position, position + 1) == 1; +} + +// Eliminates all identifer variables in column range [posStart, posLimit). +// Returns the number of variables eliminated. +unsigned FlatAffineConstraints::eliminateIdentifiers(unsigned posStart, + unsigned posLimit) { + // Return if identifier positions to eliminate are out of range. + if (posStart >= posLimit || posLimit > numIds) + return 0; + unsigned pivotCol = 0; + for (pivotCol = posStart; pivotCol < posLimit; ++pivotCol) { + // Find a row which has a non-zero coefficient in column 'j'. + unsigned pivotRow; + if (!findConstraintWithNonZeroAt(*this, pivotCol, /*isEq=*/true, + pivotRow)) { + // No pivot row in equalities with non-zero at 'pivotCol'. + if (!findConstraintWithNonZeroAt(*this, pivotCol, /*isEq=*/false, + pivotRow)) { + // If inequalities are also non-zero in 'pivotCol' it can be eliminated. + continue; + } + break; + } + + // Eliminate identifier at 'pivotCol' from each equality row. + for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) { + eliminateFromConstraint(this, i, pivotRow, pivotCol, posStart, + /*isEq=*/true); + normalizeConstraintByGCD(this, i, /*isEq=*/true); + } + + // Eliminate identifier at 'pivotCol' from each inequality row. + for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) { + eliminateFromConstraint(this, i, pivotRow, pivotCol, posStart, + /*isEq=*/false); + normalizeConstraintByGCD(this, i, /*isEq=*/false); + } + removeEquality(pivotRow); + } + // Update position limit based on number eliminated. + posLimit = pivotCol; + // Remove eliminated columns from all constraints. + removeColumnRange(posStart, posLimit); + return posLimit - posStart; +} + void FlatAffineConstraints::addEquality(ArrayRef eq) { assert(eq.size() == getNumCols()); unsigned offset = equalities.size(); @@ -446,3 +727,44 @@ void FlatAffineConstraints::addEquality(ArrayRef eq) { equalities[offset + i] = eq[i]; } } + +void FlatAffineConstraints::removeEquality(unsigned pos) { + unsigned numEqualities = getNumEqualities(); + assert(pos < numEqualities); + unsigned numCols = getNumCols(); + unsigned outputIndex = pos * numCols; + unsigned inputIndex = (pos + 1) * numCols; + unsigned numElemsToCopy = (numEqualities - pos - 1) * numCols; + for (unsigned i = 0; i < numElemsToCopy; ++i) { + equalities[outputIndex + i] = equalities[inputIndex + i]; + } + equalities.resize(equalities.size() - numCols); +} + +void FlatAffineConstraints::addInequality(ArrayRef inEq) { + assert(inEq.size() == getNumCols()); + unsigned offset = inequalities.size(); + inequalities.resize(inequalities.size() + inEq.size()); + for (unsigned i = 0, e = inEq.size(); i < e; i++) { + inequalities[offset + i] = inEq[i]; + } +} + +void FlatAffineConstraints::print(raw_ostream &os) const { + os << "\nConstraints:\n"; + for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) { + for (unsigned j = 0; j < getNumCols(); ++j) { + os << atEq(i, j) << " "; + } + os << "= 0\n"; + } + for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) { + for (unsigned j = 0; j < getNumCols(); ++j) { + os << atIneq(i, j) << " "; + } + os << ">= 0\n"; + } + os << '\n'; +} + +void FlatAffineConstraints::dump() const { print(llvm::errs()); } diff --git a/mlir/lib/IR/IntegerSet.cpp b/mlir/lib/IR/IntegerSet.cpp index 889bdd403af3..ff70128ad641 100644 --- a/mlir/lib/IR/IntegerSet.cpp +++ b/mlir/lib/IR/IntegerSet.cpp @@ -29,6 +29,18 @@ unsigned IntegerSet::getNumOperands() const { } unsigned IntegerSet::getNumConstraints() const { return set->numConstraints; } +unsigned IntegerSet::getNumEqualities() const { + unsigned numEqualities = 0; + for (unsigned i = 0, e = getNumConstraints(); i < e; i++) + if (isEq(i)) + ++numEqualities; + return numEqualities; +} + +unsigned IntegerSet::getNumInequalities() const { + return getNumConstraints() - getNumEqualities(); +} + ArrayRef IntegerSet::getConstraints() const { return set->constraints; } @@ -44,3 +56,7 @@ ArrayRef IntegerSet::getEqFlags() const { return set->eqFlags; } /// Returns true if the idx^th constraint is an equality, false if it is an /// inequality. bool IntegerSet::isEq(unsigned idx) const { return getEqFlags()[idx]; } + +MLIRContext *IntegerSet::getContext() const { + return getConstraint(0).getContext(); +} diff --git a/mlir/lib/Transforms/SimplifyAffineExpr.cpp b/mlir/lib/Transforms/SimplifyAffineExpr.cpp index 478666eb66d3..edb60c9bf232 100644 --- a/mlir/lib/Transforms/SimplifyAffineExpr.cpp +++ b/mlir/lib/Transforms/SimplifyAffineExpr.cpp @@ -24,6 +24,7 @@ #include "mlir/IR/Attributes.h" #include "mlir/IR/MLFunction.h" #include "mlir/IR/Statements.h" +#include "mlir/IR/StmtVisitor.h" #include "mlir/Transforms/Pass.h" #include "mlir/Transforms/Passes.h" @@ -36,32 +37,51 @@ namespace { /// the MLFunction. This is mainly to test the simplifyAffineExpr method. // TODO(someone): Gradually, extend this to all affine map references found in // ML functions and CFG functions. -struct SimplifyAffineExpr : public FunctionPass { - explicit SimplifyAffineExpr() {} +struct SimplifyAffineStructures : public FunctionPass, + StmtWalker { + explicit SimplifyAffineStructures() {} PassResult runOnMLFunction(MLFunction *f); // Does nothing on CFG functions for now. No reusable walkers/visitors exist // for this yet? TODO(someone). PassResult runOnCFGFunction(CFGFunction *f) { return success(); } + + void visitOperationStmt(OperationStmt *stmt); + void visitIfStmt(IfStmt *ifStmt); }; } // end anonymous namespace FunctionPass *mlir::createSimplifyAffineExprPass() { - return new SimplifyAffineExpr(); + return new SimplifyAffineStructures(); } -PassResult SimplifyAffineExpr::runOnMLFunction(MLFunction *f) { - f->walkPostOrder([&](OperationStmt *opStmt) { - for (auto attr : opStmt->getAttrs()) { - if (auto *mapAttr = dyn_cast(attr.second)) { - MutableAffineMap mMap(mapAttr->getValue()); - mMap.simplify(); - auto map = mMap.getAffineMap(); - opStmt->setAttr(attr.first, AffineMapAttr::get(map)); - } - } - }); +static IntegerSet simplifyIntegerSet(IntegerSet set) { + FlatAffineConstraints fac(set); + if (fac.isEmpty()) + return IntegerSet::getEmptySet(set.getNumDims(), set.getNumSymbols(), + set.getContext()); + return set; +} +void SimplifyAffineStructures::visitIfStmt(IfStmt *ifStmt) { + auto set = ifStmt->getCondition().getSet(); + IntegerSet simplified = simplifyIntegerSet(set); + ifStmt->setIntegerSet(simplified); +} + +void SimplifyAffineStructures::visitOperationStmt(OperationStmt *opStmt) { + for (auto attr : opStmt->getAttrs()) { + if (auto *mapAttr = dyn_cast(attr.second)) { + MutableAffineMap mMap(mapAttr->getValue()); + mMap.simplify(); + auto map = mMap.getAffineMap(); + opStmt->setAttr(attr.first, AffineMapAttr::get(map)); + } + } +} + +PassResult SimplifyAffineStructures::runOnMLFunction(MLFunction *f) { + walk(f); return success(); } diff --git a/mlir/test/Transforms/simplify.mlir b/mlir/test/Transforms/simplify.mlir index 904ba43b788e..5ac9cf8f392c 100644 --- a/mlir/test/Transforms/simplify.mlir +++ b/mlir/test/Transforms/simplify.mlir @@ -22,6 +22,38 @@ // CHECK: #map{{[0-9]+}} = (d0, d1) -> (d0 - (d0 floordiv 8) * 8, (d1 floordiv 8) * 8) #map6 = (d0, d1) -> (d0 mod 8, d1 - d1 mod 8) +// Set for test case: test_gaussian_elimination_empty_set0 +// CHECK: @@set0 = (d0, d1) : (1 == 0) +@@set0 = (d0, d1) : (2 == 0) + +// Set for test case: test_gaussian_elimination_empty_set1 +// CHECK: @@set1 = (d0, d1) : (1 == 0) +@@set1 = (d0, d1) : (1 >= 0, -1 >= 0) + +// Set for test case: test_gaussian_elimination_non_empty_set2 +// CHECK: @@set2 = (d0, d1) : (d0 - 100 == 0, d1 - 10 == 0, d0 * -1 + 100 >= 0, d1 >= 0, d1 + 101 >= 0) +@@set2 = (d0, d1) : (d0 - 100 == 0, d1 - 10 == 0, -d0 + 100 >= 0, d1 >= 0, d1 + 101 >= 0) + +// Set for test case: test_gaussian_elimination_empty_set3 +// CHECK: @@set3 = (d0, d1)[s0, s1] : (1 == 0) +@@set3 = (d0, d1)[s0, s1] : (d0 - s0 == 0, d0 + s0 == 0, s0 - 1 == 0) + +// Set for test case: test_gaussian_elimination_non_empty_set4 +// CHECK: @@set4 = (d0, d1)[s0, s1] : (d0 * 7 + d1 * 5 + s0 * 11 + s1 == 0, d0 * 5 - d1 * 11 + s0 * 7 + s1 == 0, d0 * 11 + d1 * 7 - s0 * 5 + s1 == 0, d0 * 7 + d1 * 5 + s0 * 11 + s1 == 0) +@@set4 = (d0, d1)[s0, s1] : (d0 * 7 + d1 * 5 + s0 * 11 + s1 == 0, + d0 * 5 - d1 * 11 + s0 * 7 + s1 == 0, + d0 * 11 + d1 * 7 - s0 * 5 + s1 == 0, + d0 * 7 + d1 * 5 + s0 * 11 + s1 == 0) + +// Add invalide constraints to previous non-empty set to make it empty. +// Set for test case: test_gaussian_elimination_empty_set5 +// CHECK: @@set5 = (d0, d1)[s0, s1] : (1 == 0) +@@set5 = (d0, d1)[s0, s1] : (d0 * 7 + d1 * 5 + s0 * 11 + s1 == 0, + d0 * 5 - d1 * 11 + s0 * 7 + s1 == 0, + d0 * 11 + d1 * 7 - s0 * 5 + s1 == 0, + d0 * 7 + d1 * 5 + s0 * 11 + s1 == 0, + d0 - 1 == 0, d0 + 2 == 0) + mlfunc @test() { for %n0 = 0 to 127 { for %n1 = 0 to 7 { @@ -37,3 +69,80 @@ mlfunc @test() { return } +// CHECK-LABEL: mlfunc @test_gaussian_elimination_empty_set0() { +mlfunc @test_gaussian_elimination_empty_set0() { + for %i0 = 1 to 10 { + for %i1 = 1 to 100 { + // CHECK: @@set0(%i0, %i1) + if @@set0(%i0, %i1) { + } + } + } + return +} + +// CHECK-LABEL: mlfunc @test_gaussian_elimination_empty_set1() { +mlfunc @test_gaussian_elimination_empty_set1() { + for %i0 = 1 to 10 { + for %i1 = 1 to 100 { + // CHECK: @@set1(%i0, %i1) + if @@set1(%i0, %i1) { + } + } + } + return +} + +// CHECK-LABEL: mlfunc @test_gaussian_elimination_non_empty_set2() { +mlfunc @test_gaussian_elimination_non_empty_set2() { + for %i0 = 1 to 10 { + for %i1 = 1 to 100 { + // CHECK: @@set2(%i0, %i1) + if @@set2(%i0, %i1) { + } + } + } + return +} + +// CHECK-LABEL: mlfunc @test_gaussian_elimination_empty_set3() { +mlfunc @test_gaussian_elimination_empty_set3() { + %c7 = constant 7 : index + %c11 = constant 11 : index + for %i0 = 1 to 10 { + for %i1 = 1 to 100 { + // CHECK: @@set3(%i0, %i1)[%c7, %c11] + if @@set3(%i0, %i1)[%c7, %c11] { + } + } + } + return +} + +// CHECK-LABEL: mlfunc @test_gaussian_elimination_non_empty_set4() { +mlfunc @test_gaussian_elimination_non_empty_set4() { + %c7 = constant 7 : index + %c11 = constant 11 : index + for %i0 = 1 to 10 { + for %i1 = 1 to 100 { + // CHECK: @@set4(%i0, %i1)[%c7, %c11] + if @@set4(%i0, %i1)[%c7, %c11] { + } + } + } + return +} + +// CHECK-LABEL: mlfunc @test_gaussian_elimination_empty_set5() { +mlfunc @test_gaussian_elimination_empty_set5() { + %c7 = constant 7 : index + %c11 = constant 11 : index + for %i0 = 1 to 10 { + for %i1 = 1 to 100 { + // CHECK: @@set5(%i0, %i1)[%c7, %c11] + if @@set5(%i0, %i1)[%c7, %c11] { + } + } + } + return +} \ No newline at end of file