mirror of
https://github.com/intel/llvm.git
synced 2026-02-09 01:52:26 +08:00
[mlir][SCF] Adding custom builder to SCF::WhileOp.
This is a similar builder to the one for SCF::IfOp which allows users to pass region builders to it. Refer to the builders for IfOp. Reviewed By: tpopp Differential Revision: https://reviews.llvm.org/D137709
This commit is contained in:
committed by
Benjamin Kramer
parent
beaffb041c
commit
77533d79f7
@@ -935,7 +935,7 @@ def WhileOp : SCF_Op<"while",
|
||||
|
||||
Note that the types of region arguments need not to match with each other.
|
||||
The op expects the operand types to match with argument types of the
|
||||
"before" region"; the result types to match with the trailing operand types
|
||||
"before" region; the result types to match with the trailing operand types
|
||||
of the terminator of the "before" region, and with the argument types of the
|
||||
"after" region. The following scheme can be used to share the results of
|
||||
some operations executed in the "before" region with the "after" region,
|
||||
@@ -983,7 +983,16 @@ def WhileOp : SCF_Op<"while",
|
||||
let results = (outs Variadic<AnyType>:$results);
|
||||
let regions = (region SizedRegion<1>:$before, SizedRegion<1>:$after);
|
||||
|
||||
let builders = [
|
||||
OpBuilder<(ins "TypeRange":$resultTypes, "ValueRange":$operands,
|
||||
"function_ref<void(OpBuilder &, Location, ValueRange)>":$beforeBuilder,
|
||||
"function_ref<void(OpBuilder &, Location, ValueRange)>":$afterBuilder)>
|
||||
];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
using BodyBuilderFn =
|
||||
function_ref<void(OpBuilder &, Location, ValueRange)>;
|
||||
|
||||
OperandRange getSuccessorEntryOperands(Optional<unsigned> index);
|
||||
ConditionOp getConditionOp();
|
||||
YieldOp getYieldOp();
|
||||
|
||||
@@ -71,40 +71,32 @@ static LogicalResult convertCtlzOp(math::CountLeadingZerosOp op,
|
||||
SmallVector<Type> types = {elementTy, elementTy, elementTy};
|
||||
SmallVector<Location> locations = {loc, loc, loc};
|
||||
|
||||
auto whileOp = rewriter.create<scf::WhileOp>(loc, types, operands);
|
||||
Block *before =
|
||||
rewriter.createBlock(&whileOp.getBefore(), {}, types, locations);
|
||||
Block *after =
|
||||
rewriter.createBlock(&whileOp.getAfter(), {}, types, locations);
|
||||
auto whileOp = rewriter.create<scf::WhileOp>(
|
||||
loc, types, operands,
|
||||
[&](OpBuilder &beforeBuilder, Location beforeLoc, ValueRange args) {
|
||||
// The conditional block of the while loop.
|
||||
Value input = args[0];
|
||||
Value zero = args[2];
|
||||
|
||||
// The conditional block of the while loop.
|
||||
{
|
||||
rewriter.setInsertionPointToStart(&whileOp.getBefore().front());
|
||||
Value input = before->getArgument(0);
|
||||
Value zero = before->getArgument(2);
|
||||
Value inputNotZero = beforeBuilder.create<arith::CmpIOp>(
|
||||
loc, arith::CmpIPredicate::ne, input, zero);
|
||||
beforeBuilder.create<scf::ConditionOp>(loc, inputNotZero, args);
|
||||
},
|
||||
[&](OpBuilder &afterBuilder, Location afterLoc, ValueRange args) {
|
||||
// The body of the while loop: shift right until reaching a value of 0.
|
||||
Value input = args[0];
|
||||
Value leadingZeros = args[1];
|
||||
|
||||
Value inputNotZero = rewriter.create<arith::CmpIOp>(
|
||||
loc, arith::CmpIPredicate::ne, input, zero);
|
||||
rewriter.create<scf::ConditionOp>(loc, inputNotZero,
|
||||
before->getArguments());
|
||||
}
|
||||
auto one = afterBuilder.create<arith::ConstantOp>(
|
||||
loc, IntegerAttr::get(elementTy, 1));
|
||||
auto shifted =
|
||||
afterBuilder.create<arith::ShRUIOp>(loc, resultTy, input, one);
|
||||
auto leadingZerosMinusOne = afterBuilder.create<arith::SubIOp>(
|
||||
loc, resultTy, leadingZeros, one);
|
||||
|
||||
// The body of the while loop: shift right until reaching a value of 0.
|
||||
{
|
||||
rewriter.setInsertionPointToStart(&whileOp.getAfter().front());
|
||||
Value input = after->getArgument(0);
|
||||
Value leadingZeros = after->getArgument(1);
|
||||
|
||||
auto one =
|
||||
rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 1));
|
||||
auto shifted = rewriter.create<arith::ShRUIOp>(loc, resultTy, input, one);
|
||||
auto leadingZerosMinusOne =
|
||||
rewriter.create<arith::SubIOp>(loc, resultTy, leadingZeros, one);
|
||||
|
||||
rewriter.create<scf::YieldOp>(
|
||||
loc,
|
||||
ValueRange({shifted, leadingZerosMinusOne, after->getArgument(2)}));
|
||||
}
|
||||
afterBuilder.create<scf::YieldOp>(
|
||||
loc, ValueRange({shifted, leadingZerosMinusOne, args[2]}));
|
||||
});
|
||||
|
||||
rewriter.setInsertionPointAfter(whileOp);
|
||||
rewriter.replaceOp(op, whileOp->getResult(1));
|
||||
|
||||
@@ -2669,6 +2669,34 @@ LogicalResult ReduceReturnOp::verify() {
|
||||
// WhileOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void WhileOp::build(::mlir::OpBuilder &odsBuilder,
|
||||
::mlir::OperationState &odsState, TypeRange resultTypes,
|
||||
ValueRange operands, BodyBuilderFn beforeBuilder,
|
||||
BodyBuilderFn afterBuilder) {
|
||||
assert(beforeBuilder && "the builder callback for 'before' must be present");
|
||||
assert(afterBuilder && "the builder callback for 'after' must be present");
|
||||
|
||||
odsState.addOperands(operands);
|
||||
odsState.addTypes(resultTypes);
|
||||
|
||||
OpBuilder::InsertionGuard guard(odsBuilder);
|
||||
|
||||
SmallVector<Location, 4> blockArgLocs;
|
||||
for (Value operand : operands) {
|
||||
blockArgLocs.push_back(operand.getLoc());
|
||||
}
|
||||
|
||||
Region *beforeRegion = odsState.addRegion();
|
||||
Block *beforeBlock = odsBuilder.createBlock(beforeRegion, /*insertPt=*/{},
|
||||
resultTypes, blockArgLocs);
|
||||
beforeBuilder(odsBuilder, odsState.location, beforeBlock->getArguments());
|
||||
|
||||
Region *afterRegion = odsState.addRegion();
|
||||
Block *afterBlock = odsBuilder.createBlock(afterRegion, /*insertPt=*/{},
|
||||
resultTypes, blockArgLocs);
|
||||
afterBuilder(odsBuilder, odsState.location, afterBlock->getArguments());
|
||||
}
|
||||
|
||||
OperandRange WhileOp::getSuccessorEntryOperands(Optional<unsigned> index) {
|
||||
assert(index && *index == 0 &&
|
||||
"WhileOp is expected to branch only to the first region");
|
||||
|
||||
Reference in New Issue
Block a user