Remove the 'region' field from OpBuilder.

This field wasn't updated as the insertion point changed, making it potentially dangerous given the multi-level of MLIR(e.g. 'createBlock' would always insert the new block in 'region'). This also allows for building an OpBuilder with just a context.

PiperOrigin-RevId: 257829135
This commit is contained in:
River Riddle
2019-07-12 10:43:11 -07:00
committed by Mehdi Amini
parent 8956838930
commit 8e349a48b6
6 changed files with 50 additions and 41 deletions

View File

@@ -433,7 +433,7 @@ struct LateLoweringPass : public ModulePass<LateLoweringPass> {
// Insert a `dealloc` operation right before the `return` operations, unless
// it is returned itself in which case the caller is responsible for it.
builder.getRegion()->walk([&](Operation *op) {
alloc.getContainingRegion()->walk([&](Operation *op) {
auto returnOp = dyn_cast<ReturnOp>(op);
if (!returnOp)
return;

View File

@@ -181,13 +181,13 @@ protected:
/// automatically inserted at an insertion point. The builder is copyable.
class OpBuilder : public Builder {
public:
/// Create a builder with the given context.
explicit OpBuilder(MLIRContext *ctx) : Builder(ctx) {}
/// Create a builder and set the insertion point to the start of the region.
explicit OpBuilder(Region *region)
: Builder(region->getContext()), region(region) {
explicit OpBuilder(Region *region) : Builder(region->getContext()) {
if (!region->empty())
setInsertionPoint(&region->front(), region->front().begin());
else
clearInsertionPoint();
}
explicit OpBuilder(Region &region) : OpBuilder(&region) {}
@@ -195,20 +195,17 @@ public:
/// Create a builder and set insertion point to the given operation, which
/// will cause subsequent insertions to go right before it.
OpBuilder(Operation *op) : OpBuilder(op->getContainingRegion()) {
explicit OpBuilder(Operation *op) : Builder(op->getContext()) {
setInsertionPoint(op);
}
OpBuilder(Block *block) : OpBuilder(block, block->end()) {}
explicit OpBuilder(Block *block) : OpBuilder(block, block->end()) {}
OpBuilder(Block *block, Block::iterator insertPoint)
: OpBuilder(block->getParent()) {
setInsertionPoint(block, insertPoint);
}
/// Return the region this builder is referring to.
Region *getRegion() const { return region; }
/// This class represents a saved insertion point.
class InsertPoint {
public:
@@ -281,11 +278,13 @@ public:
/// Returns the current insertion point of the builder.
Block::iterator getInsertionPoint() const { return insertPoint; }
/// Add new block and set the insertion point to the end of it. If an
/// 'insertBefore' block is passed, the block will be placed before the
/// specified block. If not, the block will be appended to the end of the
/// current region.
Block *createBlock(Block *insertBefore = nullptr);
/// Add new block and set the insertion point to the end of it. The block is
/// inserted at the provided insertion point of 'parent'.
Block *createBlock(Region *parent, Region::iterator insertPt = {});
/// Add new block and set the insertion point to the end of it. The block is
/// placed before 'insertBefore'.
Block *createBlock(Block *insertBefore);
/// Returns the current block of the builder.
Block *getBlock() const { return block; }
@@ -345,12 +344,12 @@ public:
/// and adds those mappings to the map.
Operation *clone(Operation &op, BlockAndValueMapping &mapper) {
Operation *cloneOp = op.clone(mapper);
block->getOperations().insert(insertPoint, cloneOp);
insert(cloneOp);
return cloneOp;
}
Operation *clone(Operation &op) {
Operation *cloneOp = op.clone();
block->getOperations().insert(insertPoint, cloneOp);
insert(cloneOp);
return cloneOp;
}
@@ -359,12 +358,12 @@ public:
/// updated to contain the results.
Operation *cloneWithoutRegions(Operation &op, BlockAndValueMapping &mapper) {
Operation *cloneOp = op.cloneWithoutRegions(mapper);
block->getOperations().insert(insertPoint, cloneOp);
insert(cloneOp);
return cloneOp;
}
Operation *cloneWithoutRegions(Operation &op) {
Operation *cloneOp = op.cloneWithoutRegions();
block->getOperations().insert(insertPoint, cloneOp);
insert(cloneOp);
return cloneOp;
}
@@ -373,7 +372,9 @@ private:
/// 'results'.
void tryFold(Operation *op, SmallVectorImpl<Value *> &results);
Region *region;
/// Insert the given operation at the current insertion point.
void insert(Operation *op);
Block *block = nullptr;
Block::iterator insertPoint;
};

View File

@@ -87,7 +87,7 @@ private:
LLVM::LLVMType::getInt8PtrTy(llvmDialect)));
// Insert a body block that just returns the constant.
OpBuilder ob(result.getBody());
ob.createBlock();
ob.createBlock(&result.getBody());
auto sizeConstant = ob.create<LLVM::ConstantOp>(
loc, getIndexType(),
builder.getIntegerAttr(builder.getIndexType(), blob.getValue().size()));

View File

@@ -132,7 +132,7 @@ BlockHandle mlir::edsc::BlockHandle::create(ArrayRef<Type> argTypes) {
auto *ib = currentB.getInsertionBlock();
auto ip = currentB.getInsertionPoint();
BlockHandle res;
res.block = ScopedContext::getBuilder().createBlock();
res.block = ScopedContext::getBuilder().createBlock(ib->getParent());
// createBlock sets the insertion point inside the block.
// We do not want this behavior when using declarative builders with nesting.
currentB.setInsertionPoint(ib, ip);

View File

@@ -335,29 +335,31 @@ AffineMap Builder::getShiftedAffineMap(AffineMap map, int64_t shift) {
OpBuilder::~OpBuilder() {}
/// Add new block and set the insertion point to the end of it. If an
/// 'insertBefore' block is passed, the block will be placed before the
/// specified block. If not, the block will be appended to the end of the
/// current region.
Block *OpBuilder::createBlock(Block *insertBefore) {
/// Add new block and set the insertion point to the end of it. The block is
/// inserted at the provided insertion point of 'parent'.
Block *OpBuilder::createBlock(Region *parent, Region::iterator insertPt) {
assert(parent && "expected valid parent region");
if (insertPt == Region::iterator())
insertPt = parent->end();
Block *b = new Block();
// If we are supposed to insert before a specific block, do so, otherwise add
// the block to the end of the region.
if (insertBefore)
region->getBlocks().insert(Region::iterator(insertBefore), b);
else
region->push_back(b);
parent->getBlocks().insert(insertPt, b);
setInsertionPointToEnd(b);
return b;
}
/// Add new block and set the insertion point to the end of it. The block is
/// placed before 'insertBefore'.
Block *OpBuilder::createBlock(Block *insertBefore) {
assert(insertBefore && "expected valid insertion block");
return createBlock(insertBefore->getParent(), Region::iterator(insertBefore));
}
/// Create an operation given the fields represented as an OperationState.
Operation *OpBuilder::createOperation(const OperationState &state) {
assert(block && "createOperation() called without setting builder's block");
auto *op = Operation::create(state);
block->getOperations().insert(insertPoint, op);
insert(op);
return op;
}
@@ -386,3 +388,9 @@ void OpBuilder::tryFold(Operation *op, SmallVectorImpl<Value *> &results) {
[](OpFoldResult result) { return result.get<Value *>(); });
op->erase();
}
/// Insert the given operation at the current insertion point.
void OpBuilder::insert(Operation *op) {
if (block)
block->getOperations().insert(insertPoint, op);
}

View File

@@ -53,7 +53,7 @@ public:
/// Perform the rewrites. Return true if the rewrite converges in
/// `maxIterations`.
bool simplifyFunction(int maxIterations);
bool simplifyFunction(Region *region, int maxIterations);
void addToWorklist(Operation *op) {
// Check to see if the worklist already contains this op.
@@ -146,9 +146,8 @@ private:
} // end anonymous namespace
/// Perform the rewrites.
bool GreedyPatternRewriteDriver::simplifyFunction(int maxIterations) {
Region *region = getRegion();
bool GreedyPatternRewriteDriver::simplifyFunction(Region *region,
int maxIterations) {
// Add the given operation to the worklist.
auto collectOps = [this](Operation *op) { addToWorklist(op); };
@@ -220,7 +219,8 @@ bool GreedyPatternRewriteDriver::simplifyFunction(int maxIterations) {
bool mlir::applyPatternsGreedily(FuncOp fn,
OwningRewritePatternList &&patterns) {
GreedyPatternRewriteDriver driver(fn, std::move(patterns));
bool converged = driver.simplifyFunction(maxPatternMatchIterations);
bool converged =
driver.simplifyFunction(&fn.getBody(), maxPatternMatchIterations);
LLVM_DEBUG(if (!converged) {
llvm::dbgs()
<< "The pattern rewrite doesn't converge after scanning the function "