Use statement walker for constant folding.

- makes the code compact (gets rid of MLFunction walking logic)
- makes it natural to extend to fold affine map loop bounds
  and if conditions (upcoming CL)

PiperOrigin-RevId: 214668957
This commit is contained in:
Uday Bondhugula
2018-09-26 14:26:59 -07:00
committed by jpienaar
parent be8069eb33
commit 501462ac47

View File

@@ -18,20 +18,24 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/CFGFunction.h"
#include "mlir/IR/StandardOps.h"
#include "mlir/IR/StmtVisitor.h"
#include "mlir/Transforms/Pass.h"
#include "mlir/Transforms/Passes.h"
using namespace mlir;
namespace {
/// Simple constant folding pass.
struct ConstantFold : public FunctionPass {
struct ConstantFold : public FunctionPass, StmtWalker<ConstantFold> {
// All constants in the function post folding.
SmallVector<SSAValue *, 8> existingConstants;
// Operation statements that were folded and that need to be erased.
std::vector<OperationStmt *> opStmtsToErase;
typedef std::function<SSAValue *(Attribute *, Type *)> ConstantFactoryType;
bool foldOperation(Operation *op,
SmallVectorImpl<SSAValue *> &existingConstants,
ConstantFactoryType constantFactory);
void foldStmtBlock(StmtBlock &block,
SmallVectorImpl<SSAValue *> &existingConstants);
void visitOperationStmt(OperationStmt *stmt);
PassResult runOnCFGFunction(CFGFunction *f) override;
PassResult runOnMLFunction(MLFunction *f) override;
};
@@ -40,7 +44,7 @@ struct ConstantFold : public FunctionPass {
/// Attempt to fold the specified operation, updating the IR to match. If
/// constants are found, we keep track of them in the existingConstants list.
///
/// This returns false if the operation was successfully folded.
/// This returns 0 if the operation was successfully folded.
bool ConstantFold::foldOperation(Operation *op,
SmallVectorImpl<SSAValue *> &existingConstants,
ConstantFactoryType constantFactory) {
@@ -49,7 +53,7 @@ bool ConstantFold::foldOperation(Operation *op,
// later, and don't try to fold it.
if (op->is<ConstantOp>()) {
existingConstants.push_back(op->getResult(0));
return true;
return 1;
}
// Check to see if each of the operands is a trivial constant. If so, get
@@ -95,7 +99,7 @@ bool ConstantFold::foldOperation(Operation *op,
// don't handle conditional control flow, constant PHI nodes, folding
// conditional branches, or anything else fancy.
PassResult ConstantFold::runOnCFGFunction(CFGFunction *f) {
SmallVector<SSAValue *, 8> existingConstants;
existingConstants.clear();
CFGFuncBuilder builder(f);
for (auto &bb : *f) {
@@ -129,52 +133,30 @@ PassResult ConstantFold::runOnCFGFunction(CFGFunction *f) {
return success();
}
void ConstantFold::foldStmtBlock(
StmtBlock &block, SmallVectorImpl<SSAValue *> &existingConstants) {
for (auto stmtIt = block.begin(), e = block.end(); stmtIt != e;) {
auto *stmt = &*stmtIt++;
// Fold the bodies of if and for statements.
// TODO: fold the conditions as well.
if (auto *ifStmt = dyn_cast<IfStmt>(stmt)) {
foldStmtBlock(*ifStmt->getThen(), existingConstants);
if (auto *elseBlock = ifStmt->getElse())
foldStmtBlock(*elseBlock, existingConstants);
continue;
}
// TODO: Fold constant operands of mappings into the mapping itself.
if (auto *forStmt = dyn_cast<ForStmt>(stmt)) {
foldStmtBlock(*forStmt, existingConstants);
continue;
}
// Otherwise, if this is an operation stmt, try to fold it.
auto *opStmt = dyn_cast<OperationStmt>(stmt);
if (!opStmt)
continue;
auto constantFactory = [&](Attribute *value, Type *type) -> SSAValue * {
MLFuncBuilder builder(stmt);
return builder.create<ConstantOp>(stmt->getLoc(), value, type)
->getResult();
};
if (!foldOperation(opStmt, existingConstants, constantFactory)) {
// At this point the operation is dead, remove it.
// TODO: This is assuming that all constant foldable operations have no
// side effects. When we have side effect modeling, we should verify that
// the operation is effect-free before we remove it. Until then this is
// close enough.
opStmt->eraseFromBlock();
}
// Override the walker's operation statement visit for constant folding.
void ConstantFold::visitOperationStmt(OperationStmt *stmt) {
auto constantFactory = [&](Attribute *value, Type *type) -> SSAValue * {
MLFuncBuilder builder(stmt);
return builder.create<ConstantOp>(stmt->getLoc(), value, type)->getResult();
};
if (!ConstantFold::foldOperation(stmt, existingConstants, constantFactory)) {
opStmtsToErase.push_back(stmt);
}
}
PassResult ConstantFold::runOnMLFunction(MLFunction *f) {
SmallVector<SSAValue *, 8> existingConstants;
existingConstants.clear();
opStmtsToErase.clear();
foldStmtBlock(*f, existingConstants);
walk(f);
// At this point, these operations are dead, remove them.
// TODO: This is assuming that all constant foldable operations have no
// side effects. When we have side effect modeling, we should verify that
// the operation is effect-free before we remove it. Until then this is
// close enough.
for (auto *stmt : opStmtsToErase) {
stmt->eraseFromBlock();
}
// By the time we are done, we may have simplified a bunch of code, leaving
// around dead constants. Check for them now and remove them.