Files
llvm/mlir/lib/Transforms/TestConstantFold.cpp
River Riddle a8f4b9eeeb Iterate on the operations to fold in TestConstantFold in reverse to remove the need for ConstantFoldHelper to have a flag for insertion at the head of the entry block. This also fixes an AddressSanitizer (ASan) bug in TestConstantFold due to the iteration order of operations and ConstantFoldHelper's constant insertion placement.
Note: This now means that we cannot fold chains of operations, i.e., cases where constant-foldable operations feed into each other. Given that this is a testing pass solely for constant folding, this isn't really something that we want anyway. Constant fold tests should be simple and direct, with more advanced folding/feeding being tested with the canonicalizer.

--

PiperOrigin-RevId: 242011744
2019-04-05 07:41:52 -07:00

98 lines
3.4 KiB
C++

//===- TestConstantFold.cpp - Pass to test constant folding ---------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "mlir/AffineOps/AffineOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Function.h"
#include "mlir/Pass/Pass.h"
#include "mlir/StandardOps/Ops.h"
#include "mlir/Transforms/ConstantFoldUtils.h"
#include "mlir/Transforms/Passes.h"
#include "mlir/Transforms/Utils.h"
using namespace mlir;
namespace {
/// Simple constant folding pass for testing operation constant folding.
struct TestConstantFold : public FunctionPass<TestConstantFold> {
  /// Constant operations that could be neither folded away nor de-duplicated
  /// when visited. After folding, any of these left without users are erased.
  SmallVector<Operation *, 8> existingConstants;

  /// Operations that were successfully folded and still need to be erased.
  /// Erasure is deferred until every operation has been visited.
  SmallVector<Operation *, 8> opsToErase;

  /// Attempts to fold `op` with `helper`, recording it in `opsToErase` or
  /// `existingConstants` as appropriate.
  void foldOperation(Operation *op, ConstantFoldHelper &helper);

  void runOnFunction() override;
};
} // end anonymous namespace
/// Attempts to constant fold `op` using `helper`.
///
/// Folding through the helper also covers unused or duplicated constants. On
/// success, `op` is queued in `opsToErase`; the actual erasure happens later in
/// runOnFunction. A constant op that could not be folded away is still in use,
/// so it is recorded in `existingConstants` for a later dead-constant sweep.
void TestConstantFold::foldOperation(Operation *op,
                                     ConstantFoldHelper &helper) {
  if (helper.tryToConstantFold(op)) {
    opsToErase.push_back(op);
    return;
  }

  // Only the kind of operation matters here, so a type test suffices; no
  // typed handle is needed.
  if (op->isa<ConstantOp>())
    existingConstants.push_back(op);
}
// For now, we do a simple top-down pass over a function folding constants. We
// don't handle conditional control flow, block arguments, folding conditional
// branches, or anything else fancy.
void TestConstantFold::runOnFunction() {
existingConstants.clear();
opsToErase.clear();
auto &f = getFunction();
ConstantFoldHelper helper(&f);
// Collect and fold the operations within the function.
SmallVector<Operation *, 8> ops;
f.walk([&](Operation *op) { ops.push_back(op); });
// Fold the constants in reverse so that the last generated constants from
// folding are at the beginning. This creates somewhat of a linear ordering to
// the newly generated constants that matches the operation order and improves
// the readability of test cases.
for (Operation *op : llvm::reverse(ops))
foldOperation(op, helper);
// At this point, these operations are dead, remove them.
for (auto *op : opsToErase) {
assert(op->hasNoSideEffect() && "Constant folded op with side effects?");
op->erase();
}
// By the time we are done, we may have simplified a bunch of code, leaving
// around dead constants. Check for them now and remove them.
for (auto *cst : existingConstants) {
if (cst->use_empty())
cst->erase();
}
}
/// Builds a fresh instance of the test constant folding pass.
///
/// The pass is heap-allocated with `new` and handed back as a raw base-class
/// pointer; ownership presumably transfers to the caller (e.g. a pass
/// manager) — confirm against call sites.
FunctionPassBase *mlir::createTestConstantFoldPass() {
  FunctionPassBase *freshPass = new TestConstantFold();
  return freshPass;
}
// Static registration of TestConstantFold under the name "test-constant-fold".
static PassRegistration<TestConstantFold>
    pass{"test-constant-fold", "Test operation constant folding"};