[mlir][SCF] Adding custom builder to SCF::WhileOp.

This is a similar builder to the one for SCF::IfOp which allows users to pass region builders to it. Refer to the builders for IfOp.

Reviewed By: tpopp

Differential Revision: https://reviews.llvm.org/D137709
This commit is contained in:
Mohammed Anany
2022-11-15 18:10:17 +01:00
committed by Benjamin Kramer
parent beaffb041c
commit 77533d79f7
3 changed files with 61 additions and 32 deletions

View File

@@ -935,7 +935,7 @@ def WhileOp : SCF_Op<"while",
Note that the types of region arguments need not to match with each other.
The op expects the operand types to match with argument types of the
"before" region"; the result types to match with the trailing operand types
"before" region; the result types to match with the trailing operand types
of the terminator of the "before" region, and with the argument types of the
"after" region. The following scheme can be used to share the results of
some operations executed in the "before" region with the "after" region,
@@ -983,7 +983,16 @@ def WhileOp : SCF_Op<"while",
let results = (outs Variadic<AnyType>:$results);
let regions = (region SizedRegion<1>:$before, SizedRegion<1>:$after);
let builders = [
OpBuilder<(ins "TypeRange":$resultTypes, "ValueRange":$operands,
"function_ref<void(OpBuilder &, Location, ValueRange)>":$beforeBuilder,
"function_ref<void(OpBuilder &, Location, ValueRange)>":$afterBuilder)>
];
let extraClassDeclaration = [{
using BodyBuilderFn =
function_ref<void(OpBuilder &, Location, ValueRange)>;
OperandRange getSuccessorEntryOperands(Optional<unsigned> index);
ConditionOp getConditionOp();
YieldOp getYieldOp();

View File

@@ -71,40 +71,32 @@ static LogicalResult convertCtlzOp(math::CountLeadingZerosOp op,
SmallVector<Type> types = {elementTy, elementTy, elementTy};
SmallVector<Location> locations = {loc, loc, loc};
auto whileOp = rewriter.create<scf::WhileOp>(loc, types, operands);
Block *before =
rewriter.createBlock(&whileOp.getBefore(), {}, types, locations);
Block *after =
rewriter.createBlock(&whileOp.getAfter(), {}, types, locations);
auto whileOp = rewriter.create<scf::WhileOp>(
loc, types, operands,
[&](OpBuilder &beforeBuilder, Location beforeLoc, ValueRange args) {
// The conditional block of the while loop.
Value input = args[0];
Value zero = args[2];
// The conditional block of the while loop.
{
rewriter.setInsertionPointToStart(&whileOp.getBefore().front());
Value input = before->getArgument(0);
Value zero = before->getArgument(2);
Value inputNotZero = beforeBuilder.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::ne, input, zero);
beforeBuilder.create<scf::ConditionOp>(loc, inputNotZero, args);
},
[&](OpBuilder &afterBuilder, Location afterLoc, ValueRange args) {
// The body of the while loop: shift right until reaching a value of 0.
Value input = args[0];
Value leadingZeros = args[1];
Value inputNotZero = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::ne, input, zero);
rewriter.create<scf::ConditionOp>(loc, inputNotZero,
before->getArguments());
}
auto one = afterBuilder.create<arith::ConstantOp>(
loc, IntegerAttr::get(elementTy, 1));
auto shifted =
afterBuilder.create<arith::ShRUIOp>(loc, resultTy, input, one);
auto leadingZerosMinusOne = afterBuilder.create<arith::SubIOp>(
loc, resultTy, leadingZeros, one);
// The body of the while loop: shift right until reaching a value of 0.
{
rewriter.setInsertionPointToStart(&whileOp.getAfter().front());
Value input = after->getArgument(0);
Value leadingZeros = after->getArgument(1);
auto one =
rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 1));
auto shifted = rewriter.create<arith::ShRUIOp>(loc, resultTy, input, one);
auto leadingZerosMinusOne =
rewriter.create<arith::SubIOp>(loc, resultTy, leadingZeros, one);
rewriter.create<scf::YieldOp>(
loc,
ValueRange({shifted, leadingZerosMinusOne, after->getArgument(2)}));
}
afterBuilder.create<scf::YieldOp>(
loc, ValueRange({shifted, leadingZerosMinusOne, args[2]}));
});
rewriter.setInsertionPointAfter(whileOp);
rewriter.replaceOp(op, whileOp->getResult(1));

View File

@@ -2669,6 +2669,34 @@ LogicalResult ReduceReturnOp::verify() {
// WhileOp
//===----------------------------------------------------------------------===//
void WhileOp::build(::mlir::OpBuilder &odsBuilder,
::mlir::OperationState &odsState, TypeRange resultTypes,
ValueRange operands, BodyBuilderFn beforeBuilder,
BodyBuilderFn afterBuilder) {
assert(beforeBuilder && "the builder callback for 'before' must be present");
assert(afterBuilder && "the builder callback for 'after' must be present");
odsState.addOperands(operands);
odsState.addTypes(resultTypes);
OpBuilder::InsertionGuard guard(odsBuilder);
SmallVector<Location, 4> blockArgLocs;
for (Value operand : operands) {
blockArgLocs.push_back(operand.getLoc());
}
Region *beforeRegion = odsState.addRegion();
Block *beforeBlock = odsBuilder.createBlock(beforeRegion, /*insertPt=*/{},
resultTypes, blockArgLocs);
beforeBuilder(odsBuilder, odsState.location, beforeBlock->getArguments());
Region *afterRegion = odsState.addRegion();
Block *afterBlock = odsBuilder.createBlock(afterRegion, /*insertPt=*/{},
resultTypes, blockArgLocs);
afterBuilder(odsBuilder, odsState.location, afterBlock->getArguments());
}
OperandRange WhileOp::getSuccessorEntryOperands(Optional<unsigned> index) {
assert(index && *index == 0 &&
"WhileOp is expected to branch only to the first region");