[mlir][SCF] Avoid generating unnecessary div/rem operations during coalescing (#91562)

When coalescing, if some of the loops are unit-trip we can avoid
generating div/rem instructions during delinearization. Ideally we could
use something like `affine.delinearize` to handle this, but that causes
dependence issues.
This commit is contained in:
MaheshRavishankar
2024-05-09 14:54:38 -07:00
committed by GitHub
parent 8466480bda
commit 04ce10357b
2 changed files with 124 additions and 10 deletions

View File

@@ -544,11 +544,24 @@ static void denormalizeInductionVariable(RewriterBase &rewriter, Location loc,
static Value getProductOfIntsOrIndexes(RewriterBase &rewriter, Location loc,
ArrayRef<Value> values) {
assert(!values.empty() && "unexpected empty list");
Value productOf = values.front();
for (auto v : values.drop_front()) {
productOf = rewriter.create<arith::MulIOp>(loc, productOf, v);
std::optional<Value> productOf;
for (auto v : values) {
auto vOne = getConstantIntValue(v);
if (vOne && vOne.value() == 1)
continue;
if (productOf)
productOf =
rewriter.create<arith::MulIOp>(loc, productOf.value(), v).getResult();
else
productOf = v;
}
return productOf;
if (!productOf) {
productOf = rewriter
.create<arith::ConstantOp>(
loc, rewriter.getOneAttr(values.front().getType()))
.getResult();
}
return productOf.value();
}
/// For each original loop, the value of the
@@ -562,19 +575,43 @@ static Value getProductOfIntsOrIndexes(RewriterBase &rewriter, Location loc,
static std::pair<SmallVector<Value>, SmallPtrSet<Operation *, 2>>
delinearizeInductionVariable(RewriterBase &rewriter, Location loc,
Value linearizedIv, ArrayRef<Value> ubs) {
Value previous = linearizedIv;
SmallVector<Value> delinearizedIvs(ubs.size());
SmallPtrSet<Operation *, 2> preservedUsers;
for (unsigned i = 0, e = ubs.size(); i < e; ++i) {
unsigned idx = ubs.size() - i - 1;
if (i != 0) {
llvm::BitVector isUbOne(ubs.size());
for (auto [index, ub] : llvm::enumerate(ubs)) {
auto ubCst = getConstantIntValue(ub);
if (ubCst && ubCst.value() == 1)
isUbOne.set(index);
}
// Prune the leading ubs that are all ones.
unsigned numLeadingOneUbs = 0;
for (auto [index, ub] : llvm::enumerate(ubs)) {
if (!isUbOne.test(index)) {
break;
}
delinearizedIvs[index] = rewriter.create<arith::ConstantOp>(
loc, rewriter.getZeroAttr(ub.getType()));
numLeadingOneUbs++;
}
Value previous = linearizedIv;
for (unsigned i = numLeadingOneUbs, e = ubs.size(); i < e; ++i) {
unsigned idx = ubs.size() - (i - numLeadingOneUbs) - 1;
if (i != numLeadingOneUbs && !isUbOne.test(idx + 1)) {
previous = rewriter.create<arith::DivSIOp>(loc, previous, ubs[idx + 1]);
preservedUsers.insert(previous.getDefiningOp());
}
Value iv = previous;
if (i != e - 1) {
iv = rewriter.create<arith::RemSIOp>(loc, previous, ubs[idx]);
preservedUsers.insert(iv.getDefiningOp());
if (!isUbOne.test(idx)) {
iv = rewriter.create<arith::RemSIOp>(loc, previous, ubs[idx]);
preservedUsers.insert(iv.getDefiningOp());
} else {
iv = rewriter.create<arith::ConstantOp>(
loc, rewriter.getZeroAttr(ubs[idx].getType()));
}
}
delinearizedIvs[idx] = iv;
}

View File

@@ -299,3 +299,80 @@ module attributes {transform.with_named_sequence} {
// CHECK-NOT: scf.for
// CHECK: transform.named_sequence
// -----
// Check that unnecessary operations are not generated while collapsing trip-1 loops.
// Five perfectly nested scf.for loops, each running from %c0 with step %c1.
// The loops over %iv0, %iv1 and %iv3 have upper bound %c1 (constant 1); the
// loops over %iv2 and %iv4 have the dynamic upper bounds %arg1 and %arg2.
// The outermost loop carries the `coalesce` attribute.
func.func @trip_one_loops(%arg0 : tensor<?x?xf32>, %arg1 : index, %arg2 : index) -> tensor<?x?xf32> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%0 = scf.for %iv0 = %c0 to %c1 step %c1 iter_args(%iter0 = %arg0) -> tensor<?x?xf32> {
%1 = scf.for %iv1 = %c0 to %c1 step %c1 iter_args(%iter1 = %iter0) -> tensor<?x?xf32> {
%2 = scf.for %iv2 = %c0 to %arg1 step %c1 iter_args(%iter2 = %iter1) -> tensor<?x?xf32> {
%3 = scf.for %iv3 = %c0 to %c1 step %c1 iter_args(%iter3 = %iter2) -> tensor<?x?xf32> {
%4 = scf.for %iv4 = %c0 to %arg2 step %c1 iter_args(%iter4 = %iter3) -> tensor<?x?xf32> {
// All five induction variables are operands of the innermost op, so each
// of them is observable in the output after the loops are rewritten.
%5 = "some_use"(%iter4, %iv0, %iv1, %iv2, %iv3, %iv4)
: (tensor<?x?xf32>, index, index, index, index, index) -> (tensor<?x?xf32>)
scf.yield %5 : tensor<?x?xf32>
}
scf.yield %4 : tensor<?x?xf32>
}
scf.yield %3 : tensor<?x?xf32>
}
scf.yield %2 : tensor<?x?xf32>
}
scf.yield %1 : tensor<?x?xf32>
} {coalesce}
return %0 : tensor<?x?xf32>
}
// Transform script for the function above: match the scf.for ops that carry
// the `coalesce` attribute, cast the resulting handle to
// !transform.op<"scf.for">, and apply transform.loop.coalesce to it.
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["scf.for"]} attributes {coalesce} in %arg1 : (!transform.any_op) -> !transform.any_op
%1 = transform.cast %0 : !transform.any_op to !transform.op<"scf.for">
%2 = transform.loop.coalesce %1 : (!transform.op<"scf.for">) -> (!transform.op<"scf.for">)
transform.yield
}
}
// CHECK-LABEL: func @trip_one_loops
// CHECK-SAME: , %[[ARG1:.+]]: index,
// CHECK-SAME: %[[ARG2:.+]]: index)
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK: %[[UB:.+]] = arith.muli %[[ARG1]], %[[ARG2]]
// CHECK: scf.for %[[IV:.+]] = %[[C0]] to %[[UB]] step %[[C1]]
// CHECK: %[[IV1:.+]] = arith.remsi %[[IV]], %[[ARG2]]
// CHECK: %[[IV2:.+]] = arith.divsi %[[IV]], %[[ARG2]]
// CHECK: "some_use"(%{{[a-zA-Z0-9]+}}, %[[C0]], %[[C0]], %[[IV2]], %[[C0]], %[[IV1]])
// -----
// Check that no instructions are generated when all loops except one are unit-trip.
// Three perfectly nested scf.for loops, each running from %c0 with step %c1.
// The two outer loops (%iv0, %iv1) have upper bound %c1 (constant 1); only
// the innermost loop (%iv2) has a dynamic upper bound, %arg1.
// The outermost loop carries the `coalesce` attribute.
func.func @all_outer_trip_one(%arg0 : tensor<?x?xf32>, %arg1 : index) -> tensor<?x?xf32> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%0 = scf.for %iv0 = %c0 to %c1 step %c1 iter_args(%iter0 = %arg0) -> tensor<?x?xf32> {
%1 = scf.for %iv1 = %c0 to %c1 step %c1 iter_args(%iter1 = %iter0) -> tensor<?x?xf32> {
%2 = scf.for %iv2 = %c0 to %arg1 step %c1 iter_args(%iter2 = %iter1) -> tensor<?x?xf32> {
// All three induction variables are operands of the innermost op.
%3 = "some_use"(%iter2, %iv0, %iv1, %iv2)
: (tensor<?x?xf32>, index, index, index) -> (tensor<?x?xf32>)
scf.yield %3 : tensor<?x?xf32>
}
scf.yield %2 : tensor<?x?xf32>
}
scf.yield %1 : tensor<?x?xf32>
} {coalesce}
return %0 : tensor<?x?xf32>
}
// Transform script for the function above: match the scf.for ops that carry
// the `coalesce` attribute, cast the resulting handle to
// !transform.op<"scf.for">, and apply transform.loop.coalesce to it.
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["scf.for"]} attributes {coalesce} in %arg1 : (!transform.any_op) -> !transform.any_op
%1 = transform.cast %0 : !transform.any_op to !transform.op<"scf.for">
%2 = transform.loop.coalesce %1 : (!transform.op<"scf.for">) -> (!transform.op<"scf.for">)
transform.yield
}
}
// CHECK-LABEL: func @all_outer_trip_one
// CHECK-SAME: , %[[ARG1:.+]]: index)
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK: scf.for %[[IV:.+]] = %[[C0]] to %[[ARG1]] step %[[C1]]
// CHECK: "some_use"(%{{[a-zA-Z0-9]+}}, %[[C0]], %[[C0]], %[[IV]])