mirror of
https://github.com/intel/llvm.git
synced 2026-01-26 21:53:12 +08:00
Use statement walker for constant folding.
- makes the code compact (gets rid of MLFunction walking logic) - makes it natural to extend to fold affine map loop bounds and if conditions (upcoming CL) PiperOrigin-RevId: 214668957
This commit is contained in:
committed by
jpienaar
parent
be8069eb33
commit
501462ac47
@@ -18,20 +18,24 @@
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/CFGFunction.h"
|
||||
#include "mlir/IR/StandardOps.h"
|
||||
#include "mlir/IR/StmtVisitor.h"
|
||||
#include "mlir/Transforms/Pass.h"
|
||||
#include "mlir/Transforms/Passes.h"
|
||||
using namespace mlir;
|
||||
|
||||
namespace {
|
||||
/// Simple constant folding pass.
|
||||
struct ConstantFold : public FunctionPass {
|
||||
struct ConstantFold : public FunctionPass, StmtWalker<ConstantFold> {
|
||||
// All constants in the function post folding.
|
||||
SmallVector<SSAValue *, 8> existingConstants;
|
||||
// Operation statements that were folded and that need to be erased.
|
||||
std::vector<OperationStmt *> opStmtsToErase;
|
||||
typedef std::function<SSAValue *(Attribute *, Type *)> ConstantFactoryType;
|
||||
|
||||
bool foldOperation(Operation *op,
|
||||
SmallVectorImpl<SSAValue *> &existingConstants,
|
||||
ConstantFactoryType constantFactory);
|
||||
void foldStmtBlock(StmtBlock &block,
|
||||
SmallVectorImpl<SSAValue *> &existingConstants);
|
||||
void visitOperationStmt(OperationStmt *stmt);
|
||||
PassResult runOnCFGFunction(CFGFunction *f) override;
|
||||
PassResult runOnMLFunction(MLFunction *f) override;
|
||||
};
|
||||
@@ -40,7 +44,7 @@ struct ConstantFold : public FunctionPass {
|
||||
/// Attempt to fold the specified operation, updating the IR to match. If
|
||||
/// constants are found, we keep track of them in the existingConstants list.
|
||||
///
|
||||
/// This returns false if the operation was successfully folded.
|
||||
/// This returns 0 if the operation was successfully folded.
|
||||
bool ConstantFold::foldOperation(Operation *op,
|
||||
SmallVectorImpl<SSAValue *> &existingConstants,
|
||||
ConstantFactoryType constantFactory) {
|
||||
@@ -49,7 +53,7 @@ bool ConstantFold::foldOperation(Operation *op,
|
||||
// later, and don't try to fold it.
|
||||
if (op->is<ConstantOp>()) {
|
||||
existingConstants.push_back(op->getResult(0));
|
||||
return true;
|
||||
return 1;
|
||||
}
|
||||
|
||||
// Check to see if each of the operands is a trivial constant. If so, get
|
||||
@@ -95,7 +99,7 @@ bool ConstantFold::foldOperation(Operation *op,
|
||||
// don't handle conditional control flow, constant PHI nodes, folding
|
||||
// conditional branches, or anything else fancy.
|
||||
PassResult ConstantFold::runOnCFGFunction(CFGFunction *f) {
|
||||
SmallVector<SSAValue *, 8> existingConstants;
|
||||
existingConstants.clear();
|
||||
CFGFuncBuilder builder(f);
|
||||
|
||||
for (auto &bb : *f) {
|
||||
@@ -129,52 +133,30 @@ PassResult ConstantFold::runOnCFGFunction(CFGFunction *f) {
|
||||
return success();
|
||||
}
|
||||
|
||||
void ConstantFold::foldStmtBlock(
|
||||
StmtBlock &block, SmallVectorImpl<SSAValue *> &existingConstants) {
|
||||
for (auto stmtIt = block.begin(), e = block.end(); stmtIt != e;) {
|
||||
auto *stmt = &*stmtIt++;
|
||||
|
||||
// Fold the bodies of if and for statements.
|
||||
// TODO: fold the conditions as well.
|
||||
if (auto *ifStmt = dyn_cast<IfStmt>(stmt)) {
|
||||
foldStmtBlock(*ifStmt->getThen(), existingConstants);
|
||||
if (auto *elseBlock = ifStmt->getElse())
|
||||
foldStmtBlock(*elseBlock, existingConstants);
|
||||
continue;
|
||||
}
|
||||
|
||||
// TODO: Fold constant operands of mappings into the mapping itself.
|
||||
if (auto *forStmt = dyn_cast<ForStmt>(stmt)) {
|
||||
foldStmtBlock(*forStmt, existingConstants);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Otherwise, if this is an operation stmt, try to fold it.
|
||||
auto *opStmt = dyn_cast<OperationStmt>(stmt);
|
||||
if (!opStmt)
|
||||
continue;
|
||||
|
||||
auto constantFactory = [&](Attribute *value, Type *type) -> SSAValue * {
|
||||
MLFuncBuilder builder(stmt);
|
||||
return builder.create<ConstantOp>(stmt->getLoc(), value, type)
|
||||
->getResult();
|
||||
};
|
||||
|
||||
if (!foldOperation(opStmt, existingConstants, constantFactory)) {
|
||||
// At this point the operation is dead, remove it.
|
||||
// TODO: This is assuming that all constant foldable operations have no
|
||||
// side effects. When we have side effect modeling, we should verify that
|
||||
// the operation is effect-free before we remove it. Until then this is
|
||||
// close enough.
|
||||
opStmt->eraseFromBlock();
|
||||
}
|
||||
// Override the walker's operation statement visit for constant folding.
|
||||
void ConstantFold::visitOperationStmt(OperationStmt *stmt) {
|
||||
auto constantFactory = [&](Attribute *value, Type *type) -> SSAValue * {
|
||||
MLFuncBuilder builder(stmt);
|
||||
return builder.create<ConstantOp>(stmt->getLoc(), value, type)->getResult();
|
||||
};
|
||||
if (!ConstantFold::foldOperation(stmt, existingConstants, constantFactory)) {
|
||||
opStmtsToErase.push_back(stmt);
|
||||
}
|
||||
}
|
||||
|
||||
PassResult ConstantFold::runOnMLFunction(MLFunction *f) {
|
||||
SmallVector<SSAValue *, 8> existingConstants;
|
||||
existingConstants.clear();
|
||||
opStmtsToErase.clear();
|
||||
|
||||
foldStmtBlock(*f, existingConstants);
|
||||
walk(f);
|
||||
// At this point, these operations are dead, remove them.
|
||||
// TODO: This is assuming that all constant foldable operations have no
|
||||
// side effects. When we have side effect modeling, we should verify that
|
||||
// the operation is effect-free before we remove it. Until then this is
|
||||
// close enough.
|
||||
for (auto *stmt : opStmtsToErase) {
|
||||
stmt->eraseFromBlock();
|
||||
}
|
||||
|
||||
// By the time we are done, we may have simplified a bunch of code, leaving
|
||||
// around dead constants. Check for them now and remove them.
|
||||
|
||||
Reference in New Issue
Block a user