diff --git a/mlir/lib/Transforms/ConstantFold.cpp b/mlir/lib/Transforms/ConstantFold.cpp index df2d72f0b39c..996ff3862333 100644 --- a/mlir/lib/Transforms/ConstantFold.cpp +++ b/mlir/lib/Transforms/ConstantFold.cpp @@ -18,20 +18,24 @@ #include "mlir/IR/Builders.h" #include "mlir/IR/CFGFunction.h" #include "mlir/IR/StandardOps.h" +#include "mlir/IR/StmtVisitor.h" #include "mlir/Transforms/Pass.h" #include "mlir/Transforms/Passes.h" using namespace mlir; namespace { /// Simple constant folding pass. -struct ConstantFold : public FunctionPass { +struct ConstantFold : public FunctionPass, StmtWalker { + // All constants in the function post folding. + SmallVector existingConstants; + // Operation statements that were folded and that need to be erased. + std::vector opStmtsToErase; typedef std::function ConstantFactoryType; bool foldOperation(Operation *op, SmallVectorImpl &existingConstants, ConstantFactoryType constantFactory); - void foldStmtBlock(StmtBlock &block, - SmallVectorImpl &existingConstants); + void visitOperationStmt(OperationStmt *stmt); PassResult runOnCFGFunction(CFGFunction *f) override; PassResult runOnMLFunction(MLFunction *f) override; }; @@ -40,7 +44,7 @@ struct ConstantFold : public FunctionPass { /// Attempt to fold the specified operation, updating the IR to match. If /// constants are found, we keep track of them in the existingConstants list. /// -/// This returns false if the operation was successfully folded. +/// This returns 0 if the operation was successfully folded. bool ConstantFold::foldOperation(Operation *op, SmallVectorImpl &existingConstants, ConstantFactoryType constantFactory) { @@ -49,7 +53,7 @@ bool ConstantFold::foldOperation(Operation *op, // later, and don't try to fold it. if (op->is()) { existingConstants.push_back(op->getResult(0)); - return true; + return 1; } // Check to see if each of the operands is a trivial constant. If so, get @@ -95,7 +99,7 @@ bool ConstantFold::foldOperation(Operation *op, // don't handle conditional control flow, constant PHI nodes, folding // conditional branches, or anything else fancy. PassResult ConstantFold::runOnCFGFunction(CFGFunction *f) { - SmallVector existingConstants; + existingConstants.clear(); CFGFuncBuilder builder(f); for (auto &bb : *f) { @@ -129,52 +133,30 @@ PassResult ConstantFold::runOnCFGFunction(CFGFunction *f) { return success(); } -void ConstantFold::foldStmtBlock( - StmtBlock &block, SmallVectorImpl &existingConstants) { - for (auto stmtIt = block.begin(), e = block.end(); stmtIt != e;) { - auto *stmt = &*stmtIt++; - - // Fold the bodies of if and for statements. - // TODO: fold the conditions as well. - if (auto *ifStmt = dyn_cast(stmt)) { - foldStmtBlock(*ifStmt->getThen(), existingConstants); - if (auto *elseBlock = ifStmt->getElse()) - foldStmtBlock(*elseBlock, existingConstants); - continue; - } - - // TODO: Fold constant operands of mappings into the mapping itself. - if (auto *forStmt = dyn_cast(stmt)) { - foldStmtBlock(*forStmt, existingConstants); - continue; - } - - // Otherwise, if this is an operation stmt, try to fold it. - auto *opStmt = dyn_cast(stmt); - if (!opStmt) - continue; - - auto constantFactory = [&](Attribute *value, Type *type) -> SSAValue * { - MLFuncBuilder builder(stmt); - return builder.create(stmt->getLoc(), value, type) - ->getResult(); - }; - - if (!foldOperation(opStmt, existingConstants, constantFactory)) { - // At this point the operation is dead, remove it. - // TODO: This is assuming that all constant foldable operations have no - // side effects. When we have side effect modeling, we should verify that - // the operation is effect-free before we remove it. Until then this is - // close enough. - opStmt->eraseFromBlock(); - } +// Override the walker's operation statement visit for constant folding. +void ConstantFold::visitOperationStmt(OperationStmt *stmt) { + auto constantFactory = [&](Attribute *value, Type *type) -> SSAValue * { + MLFuncBuilder builder(stmt); + return builder.create(stmt->getLoc(), value, type)->getResult(); + }; + if (!ConstantFold::foldOperation(stmt, existingConstants, constantFactory)) { + opStmtsToErase.push_back(stmt); } } PassResult ConstantFold::runOnMLFunction(MLFunction *f) { - SmallVector existingConstants; + existingConstants.clear(); + opStmtsToErase.clear(); - foldStmtBlock(*f, existingConstants); + walk(f); + // At this point, these operations are dead, remove them. + // TODO: This is assuming that all constant foldable operations have no + // side effects. When we have side effect modeling, we should verify that + // the operation is effect-free before we remove it. Until then this is + // close enough. + for (auto *stmt : opStmtsToErase) { + stmt->eraseFromBlock(); + } // By the time we are done, we may have simplified a bunch of code, leaving // around dead constants. Check for them now and remove them.