[MLIR][Presburger] support IntegerRelation::convertIdKind

Reviewed By: Groverkss

Differential Revision: https://reviews.llvm.org/D122154
This commit is contained in:
Arjun P
2022-03-23 11:10:48 +00:00
parent 8b62dd3cd6
commit 87cffeb635
4 changed files with 39 additions and 25 deletions

View File

@@ -390,9 +390,14 @@ public:
/// O(VC) time.
void removeRedundantConstraints();
/// Converts identifiers in the column range [idStart, idLimit) to local
/// variables.
void convertDimToLocal(unsigned dimStart, unsigned dimLimit);
/// Converts identifiers of kind srcKind in the range [idStart, idLimit) to
/// variables of kind dstKind and placed after all the other variables of kind
/// dstKind. The internal ordering among the moved variables is preserved.
void convertIdKind(IdKind srcKind, unsigned idStart, unsigned idLimit,
IdKind dstKind);
void convertToLocal(IdKind kind, unsigned idStart, unsigned idLimit) {
convertIdKind(kind, idStart, idLimit, IdKind::Local);
}
/// Adds additional local ids to the sets such that they both have the union
/// of the local ids in each set, without changing the set of points that

View File

@@ -1117,23 +1117,31 @@ void IntegerRelation::removeRedundantLocalVars() {
}
}
void IntegerRelation::convertDimToLocal(unsigned dimStart, unsigned dimLimit) {
assert(dimLimit <= getNumDimIds() && "Invalid dim pos range");
void IntegerRelation::convertIdKind(IdKind srcKind, unsigned idStart,
unsigned idLimit, IdKind dstKind) {
assert(idLimit <= getNumIdKind(srcKind) && "Invalid id range");
if (dimStart >= dimLimit)
if (idStart >= idLimit)
return;
// Append new local variables corresponding to the dimensions to be converted.
unsigned convertCount = dimLimit - dimStart;
unsigned newLocalIdStart = getNumIds();
appendId(IdKind::Local, convertCount);
unsigned newIdsBegin = getIdKindEnd(dstKind);
unsigned convertCount = idLimit - idStart;
appendId(dstKind, convertCount);
// Swap the new local variables with dimensions.
//
// Essentially, this moves the information corresponding to the specified ids
// of kind `srcKind` to the `convertCount` newly created ids of kind
// `dstKind`. In particular, this moves the columns in the constraint
// matrices, and zeros out the initially occupied columns (because the newly
// created ids we're swapping with were zero-initialized).
unsigned offset = getIdKindOffset(srcKind);
for (unsigned i = 0; i < convertCount; ++i)
swapId(i + dimStart, i + newLocalIdStart);
swapId(offset + idStart + i, newIdsBegin + i);
// Remove dimensions converted to local variables.
removeIdRange(IdKind::SetDim, dimStart, dimLimit);
// Complete the move by deleting the initially occupied columns.
removeIdRange(srcKind, idStart, idLimit);
}
void IntegerRelation::addBound(BoundType type, unsigned pos, int64_t value) {

View File

@@ -1618,15 +1618,15 @@ AffineMap mlir::alignAffineMapWithValues(AffineMap map, ValueRange operands,
FlatAffineValueConstraints FlatAffineRelation::getDomainSet() const {
FlatAffineValueConstraints domain = *this;
// Convert all range variables to local variables.
domain.convertDimToLocal(getNumDomainDims(),
getNumDomainDims() + getNumRangeDims());
domain.convertToLocal(IdKind::SetDim, getNumDomainDims(),
getNumDomainDims() + getNumRangeDims());
return domain;
}
FlatAffineValueConstraints FlatAffineRelation::getRangeSet() const {
FlatAffineValueConstraints range = *this;
// Convert all domain variables to local variables.
range.convertDimToLocal(0, getNumDomainDims());
range.convertToLocal(IdKind::SetDim, 0, getNumDomainDims());
return range;
}
@@ -1658,12 +1658,13 @@ void FlatAffineRelation::compose(const FlatAffineRelation &other) {
// Convert `rel` from [otherDomain] -> [otherRange thisRange] to
// [otherDomain] -> [thisRange] by converting first otherRange range ids
// to local ids.
rel.convertDimToLocal(rel.getNumDomainDims(),
rel.getNumDomainDims() + removeDims);
rel.convertToLocal(IdKind::SetDim, rel.getNumDomainDims(),
rel.getNumDomainDims() + removeDims);
// Convert `this` from [otherDomain thisDomain] -> [thisRange] to
// [otherDomain] -> [thisRange] by converting last thisDomain domain ids
// to local ids.
convertDimToLocal(getNumDomainDims() - removeDims, getNumDomainDims());
convertToLocal(IdKind::SetDim, getNumDomainDims() - removeDims,
getNumDomainDims());
auto thisMaybeValues = getMaybeDimValues();
auto relMaybeValues = rel.getMaybeDimValues();

View File

@@ -707,7 +707,7 @@ TEST(IntegerPolyhedronTest, computeLocalReprTightUpperBound) {
IntegerPolyhedron poly =
parsePoly("(i, j, q) : (4*q - i - j + 2 >= 0, -4*q + i + j >= 0)");
// Convert `q` to a local variable.
poly.convertDimToLocal(2, 3);
poly.convertToLocal(IdKind::SetDim, 2, 3);
std::vector<SmallVector<int64_t, 8>> divisions = {{1, 1, 0, 1}};
SmallVector<unsigned, 8> denoms = {4};
@@ -721,7 +721,7 @@ TEST(IntegerPolyhedronTest, computeLocalReprFromEquality) {
{
IntegerPolyhedron poly = parsePoly("(i, j, q) : (-4*q + i + j == 0)");
// Convert `q` to a local variable.
poly.convertDimToLocal(2, 3);
poly.convertToLocal(IdKind::SetDim, 2, 3);
std::vector<SmallVector<int64_t, 8>> divisions = {{-1, -1, 0, 0}};
SmallVector<unsigned, 8> denoms = {4};
@@ -731,7 +731,7 @@ TEST(IntegerPolyhedronTest, computeLocalReprFromEquality) {
{
IntegerPolyhedron poly = parsePoly("(i, j, q) : (4*q - i - j == 0)");
// Convert `q` to a local variable.
poly.convertDimToLocal(2, 3);
poly.convertToLocal(IdKind::SetDim, 2, 3);
std::vector<SmallVector<int64_t, 8>> divisions = {{-1, -1, 0, 0}};
SmallVector<unsigned, 8> denoms = {4};
@@ -741,7 +741,7 @@ TEST(IntegerPolyhedronTest, computeLocalReprFromEquality) {
{
IntegerPolyhedron poly = parsePoly("(i, j, q) : (3*q + i + j - 2 == 0)");
// Convert `q` to a local variable.
poly.convertDimToLocal(2, 3);
poly.convertToLocal(IdKind::SetDim, 2, 3);
std::vector<SmallVector<int64_t, 8>> divisions = {{1, 1, 0, -2}};
SmallVector<unsigned, 8> denoms = {3};
@@ -756,7 +756,7 @@ TEST(IntegerPolyhedronTest, computeLocalReprFromEqualityAndInequality) {
parsePoly("(i, j, q, k) : (-3*k + i + j == 0, 4*q - "
"i - j + 2 >= 0, -4*q + i + j >= 0)");
// Convert `q` and `k` to local variables.
poly.convertDimToLocal(2, 4);
poly.convertToLocal(IdKind::SetDim, 2, 4);
std::vector<SmallVector<int64_t, 8>> divisions = {{1, 1, 0, 0, 1},
{-1, -1, 0, 0, 0}};
@@ -770,7 +770,7 @@ TEST(IntegerPolyhedronTest, computeLocalReprNoRepr) {
IntegerPolyhedron poly =
parsePoly("(x, q) : (x - 3 * q >= 0, -x + 3 * q + 3 >= 0)");
// Convert q to a local variable.
poly.convertDimToLocal(1, 2);
poly.convertToLocal(IdKind::SetDim, 1, 2);
std::vector<SmallVector<int64_t, 8>> divisions = {{0, 0, 0}};
SmallVector<unsigned, 8> denoms = {0};
@@ -783,7 +783,7 @@ TEST(IntegerPolyhedronTest, computeLocalReprNegConstNormalize) {
IntegerPolyhedron poly =
parsePoly("(x, q) : (-1 - 3*x - 6 * q >= 0, 6 + 3*x + 6*q >= 0)");
// Convert q to a local variable.
poly.convertDimToLocal(1, 2);
poly.convertToLocal(IdKind::SetDim, 1, 2);
// q = floor((-1/3 - x)/2)
// = floor((1/3) + (-1 - x)/2)