Files
llvm/mlir/lib/Transforms/RemoveDeadValues.cpp
Matthias Springer e6110cb339 [mlir][Transforms] Fix crash in -remove-dead-values on private functions (#169269)
This commit fixes two crashes in the `-remove-dead-values` pass related
to private functions.

Private functions are considered entirely "dead" by the liveness
analysis, which drives the `-remove-dead-values` pass.

The `-remove-dead-values` pass removes dead block arguments from private
functions. Private functions are entirely dead, so all of their block
arguments are removed. However, the pass did not correctly update all
users of these dropped block arguments.

1. A side-effecting operation must be removed if one of its operands is
dead. Otherwise, the operation would end up with a NULL operand. Note:
The liveness analysis would not have marked an SSA value as "dead" if it
had a reachable side-effecting user. (Therefore, it is safe to erase
such side-effecting operations.)
2. A branch operation must be removed if one of its non-forwarded
operands is dead. (E.g., the condition value of a `cf.cond_br`.)
Whenever a terminator is removed, a `ub.unreachable` operation is
inserted. This fixes #158760.
2025-12-03 08:35:05 +01:00

962 lines
40 KiB
C++

//===- RemoveDeadValues.cpp - Remove Dead Values --------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// The goal of this pass is optimization (reducing runtime) by removing
// unnecessary instructions. Unlike other passes that rely on local information
// gathered from patterns to accomplish optimization, this pass uses a full
// analysis of the IR, specifically, liveness analysis, and is thus more
// powerful.
//
// Currently, this pass performs the following optimizations:
// (A) Removes function arguments that are not live,
// (B) Removes function return values that are not live across all callers of
// the function,
// (C) Removes unnecessary operands, results, region arguments, and region
// terminator operands of region branch ops, and,
// (D) Removes simple and region branch ops that have all non-live results and
// don't affect memory in any way,
//
// iff
//
// the IR doesn't have any non-function symbol ops, non-call symbol user ops and
// branch ops.
//
// Here, a "simple op" refers to an op that isn't a symbol op, symbol-user op,
// region branch op, branch op, region branch terminator op, or return-like.
//
//===----------------------------------------------------------------------===//
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
#include "mlir/Analysis/DataFlow/LivenessAnalysis.h"
#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/DebugLog.h"
#include <cassert>
#include <cstddef>
#include <memory>
#include <optional>
#include <vector>
#define DEBUG_TYPE "remove-dead-values"
namespace mlir {
#define GEN_PASS_DEF_REMOVEDEADVALUES
#include "mlir/Transforms/Passes.h.inc"
} // namespace mlir
using namespace mlir;
using namespace mlir::dataflow;
//===----------------------------------------------------------------------===//
// RemoveDeadValues Pass
//===----------------------------------------------------------------------===//
namespace {
// The structures below are filled with operations, values, and arguments to
// erase. Collecting them first separates the analysis phase from the IR
// modification phase; otherwise the analysis would operate on a half-deleted
// IR, which is incorrect.
/// A private function whose non-live arguments and results are erased during
/// the final cleanup (phase 5 of `cleanUpDeadVals`).
struct FunctionToCleanUp {
/// The function to shrink.
FunctionOpInterface funcOp;
/// Bit i is set iff argument i of `funcOp` is non-live and should be erased.
BitVector nonLiveArgs;
/// Bit i is set iff result i of `funcOp` is non-live across all callers and
/// should be erased.
BitVector nonLiveRets;
};
/// An operation together with a mask of entries to erase from it. Whether the
/// mask indexes operands or results depends on the list the entry is stored in
/// (`RDVFinalCleanupList::operands` or `RDVFinalCleanupList::results`).
struct OperationToCleanup {
/// The operation to shrink.
Operation *op;
/// Bit i is set iff the i-th operand/result of `op` should be erased.
BitVector nonLive;
/// Callee cached at collection time so the cleanup phase can match call sites
/// to erased callee arguments without a symbol lookup.
Operation *callee =
nullptr; // Optional: For CallOpInterface ops, stores the callee function
};
/// A block whose non-live arguments are erased during cleanup (phase 1 of
/// `cleanUpDeadVals`).
struct BlockArgsToCleanup {
/// The block owning the arguments.
Block *b;
/// Bit i is set iff argument i of `b` should be erased. The cleanup skips the
/// entry if the block's argument count no longer matches this mask's size.
BitVector nonLiveArgs;
};
/// Successor operands of a branch op that are erased during cleanup (phase 2
/// of `cleanUpDeadVals`), mirroring the erased successor block arguments.
struct SuccessorOperandsToCleanup {
/// The branch whose successor operands are shrunk.
BranchOpInterface branch;
/// Index of the successor whose operand list is shrunk.
unsigned successorIndex;
/// Bit i is set iff the i-th successor operand should be erased.
BitVector nonLiveOperands;
};
/// Everything the analysis phase decided to remove. It is consumed once by
/// `cleanUpDeadVals` after all operations have been processed.
struct RDVFinalCleanupList {
/// Operations to erase entirely. Terminators are replaced by `ub.unreachable`.
SmallVector<Operation *> operations;
/// Values whose uses are dropped (non-live function arguments).
SmallVector<Value> values;
/// Functions whose arguments/results are erased.
SmallVector<FunctionToCleanUp> functions;
/// Operations whose masked operands are erased.
SmallVector<OperationToCleanup> operands;
/// Operations whose masked results are erased.
SmallVector<OperationToCleanup> results;
/// Blocks whose masked arguments are erased.
SmallVector<BlockArgsToCleanup> blocks;
/// Branch successor operand lists whose masked operands are erased.
SmallVector<SuccessorOperandsToCleanup> successorOperands;
};
// Some helper functions...
/// Returns true as soon as some value in `values` may be live.
///
/// A value recorded in `nonLiveSet` is treated as dead. Otherwise the verdict
/// of `la` is used; a value for which `la` has no liveness information is
/// conservatively treated as live.
static bool hasLive(ValueRange values, const DenseSet<Value> &nonLiveSet,
                    RunLivenessAnalysis &la) {
  for (Value candidate : values) {
    if (nonLiveSet.contains(candidate)) {
      LDBG() << "Value " << candidate << " is already marked non-live (dead)";
      continue;
    }

    const Liveness *info = la.getLiveness(candidate);
    if (info == nullptr) {
      LDBG() << "Value " << candidate
             << " has no liveness info, conservatively considered live";
      return true;
    }

    if (!info->isLive) {
      LDBG() << "Value " << candidate
             << " is dead according to liveness analysis";
      continue;
    }

    LDBG() << "Value " << candidate
           << " is live according to liveness analysis";
    return true;
  }
  return false;
}
/// Builds a liveness mask for `values`: the returned BitVector has
/// `values.size()` bits, and bit i is set iff the i-th value may be live.
///
/// A value already recorded in `nonLiveSet` is dead. A value without liveness
/// information in `la` is kept live. This happens when the analysis cannot
/// decide, and also for values created by this pass after the analysis ran
/// (e.g. results of ops rebuilt with fewer results), which are known to be
/// live because they were not erased.
static BitVector markLives(ValueRange values, const DenseSet<Value> &nonLiveSet,
                           RunLivenessAnalysis &la) {
  BitVector liveMask(values.size(), true);
  for (unsigned idx = 0, e = values.size(); idx < e; ++idx) {
    Value candidate = values[idx];

    if (nonLiveSet.contains(candidate)) {
      liveMask.reset(idx);
      LDBG() << "Value " << candidate
             << " is already marked non-live (dead) at index " << idx;
      continue;
    }

    const Liveness *info = la.getLiveness(candidate);
    if (!info) {
      // Unknown liveness: keep the bit set (conservatively live).
      LDBG() << "Value " << candidate << " at index " << idx
             << " has no liveness info, conservatively considered live";
      continue;
    }

    if (info->isLive) {
      LDBG() << "Value " << candidate << " at index " << idx
             << " is live according to liveness analysis";
      continue;
    }

    liveMask.reset(idx);
    LDBG() << "Value " << candidate << " at index " << idx
           << " is dead according to liveness analysis";
  }
  return liveMask;
}
/// Inserts into `nonLiveSet` every value of `range` whose position is set in
/// `nonLive`. `nonLive` is indexed by position in `range` and is expected to
/// have at least `range.size()` bits.
static void collectNonLiveValues(DenseSet<Value> &nonLiveSet, ValueRange range,
                                 const BitVector &nonLive) {
  for (unsigned idx = 0, e = range.size(); idx < e; ++idx) {
    if (nonLive[idx]) {
      Value dead = range[idx];
      nonLiveSet.insert(dead);
      LDBG() << "Marking value " << dead << " as non-live (dead) at index "
             << idx;
    }
  }
}
/// Drop the uses of the i-th result of `op` and then erase it iff toErase[i]
/// is 1.
static void dropUsesAndEraseResults(Operation *op, BitVector toErase) {
assert(op->getNumResults() == toErase.size() &&
"expected the number of results in `op` and the size of `toErase` to "
"be the same");
std::vector<Type> newResultTypes;
for (OpResult result : op->getResults())
if (!toErase[result.getResultNumber()])
newResultTypes.push_back(result.getType());
OpBuilder builder(op);
builder.setInsertionPointAfter(op);
OperationState state(op->getLoc(), op->getName().getStringRef(),
op->getOperands(), newResultTypes, op->getAttrs());
for (unsigned i = 0, e = op->getNumRegions(); i < e; ++i)
state.addRegion();
Operation *newOp = builder.create(state);
for (const auto &[index, region] : llvm::enumerate(op->getRegions())) {
Region &newRegion = newOp->getRegion(index);
// Move all blocks of `region` into `newRegion`.
Block *temp = new Block();
newRegion.push_back(temp);
while (!region.empty())
region.front().moveBefore(temp);
temp->erase();
}
unsigned indexOfNextNewCallOpResultToReplace = 0;
for (auto [index, result] : llvm::enumerate(op->getResults())) {
assert(result && "expected result to be non-null");
if (toErase[index]) {
result.dropAllUses();
} else {
result.replaceAllUsesWith(
newOp->getResult(indexOfNextNewCallOpResultToReplace++));
}
}
op->erase();
}
/// Returns a pointer to each `OpOperand` backing `operands`, in order. The
/// pointers refer to the operand storage of the owning operation.
static SmallVector<OpOperand *> operandsToOpOperands(OperandRange operands) {
  SmallVector<OpOperand *> result;
  OpOperand *first = operands.getBase();
  for (OpOperand *cur = first, *last = first + operands.size(); cur != last;
       ++cur)
    result.push_back(cur);
  return result;
}
/// Decides whether the simple operation `op` is dead, using the liveness
/// analysis `la` and the values already recorded in `nonLiveSet`.
///
/// The op is dead when either:
///  - one of its operands is dead (regardless of side effects), or
///  - it is memory-effect free and none of its results is live.
/// A dead op is appended to `cl.operations`, and all of its results are
/// recorded in `nonLiveSet`.
///
/// A "simple" operation is one that is NOT function-like, call-like, a region
/// branch op, a branch op, a region branch terminator, or return-like.
static void processSimpleOp(Operation *op, RunLivenessAnalysis &la,
                            DenseSet<Value> &nonLiveSet,
                            RDVFinalCleanupList &cl) {
  // An op consuming a dead value is itself dead, whatever its side effects:
  // the liveness analysis would not have marked an SSA value as "dead" if it
  // had a reachable side-effecting user.
  BitVector deadOperands = markLives(op->getOperands(), nonLiveSet, la);
  deadOperands.flip();

  if (deadOperands.any()) {
    LDBG() << "Simple op has dead operands, so the op must be dead: "
           << OpWithFlags(op, OpPrintingFlags().skipRegions());
    assert(!hasLive(op->getResults(), nonLiveSet, la) &&
           "expected the op to have no live results");
  } else if (!isMemoryEffectFree(op) ||
             hasLive(op->getResults(), nonLiveSet, la)) {
    LDBG() << "Simple op is not memory effect free or has live results, "
              "preserving it: "
           << OpWithFlags(op, OpPrintingFlags().skipRegions());
    return;
  } else {
    LDBG() << "Simple op has all dead results and is memory effect free, "
              "scheduling "
              "for removal: "
           << OpWithFlags(op, OpPrintingFlags().skipRegions());
  }

  // Both "dead" paths end here: schedule the op and mark all of its results.
  cl.operations.push_back(op);
  collectNonLiveValues(nonLiveSet, op->getResults(),
                       BitVector(op->getNumResults(), true));
}
/// Process a function-like operation `funcOp` using the liveness analysis `la`
/// and the IR in `module`. Public and external functions are skipped.
/// Otherwise:
/// (1) Add its non-live arguments to a list for future removal.
/// (2) Enqueue each of its call sites so that the call-site-specific operand
///     cleanup in `cleanUpDeadVals` runs for it.
/// (3) Identify and enqueue unnecessary terminator operands (return values
///     that are non-live across all callers) for removal.
/// (4) Enqueue the non-live arguments and return values for removal.
/// (5) Enqueue the corresponding results of its callers for removal.
/// (6) Mark those caller results as non-live values.
///
/// Every symbol use of `funcOp` within `module` is expected to be a call-like
/// operation. This is checked with assertions.
static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
                          RunLivenessAnalysis &la, DenseSet<Value> &nonLiveSet,
                          RDVFinalCleanupList &cl) {
  LDBG() << "Processing function op: "
         << OpWithFlags(funcOp, OpPrintingFlags().skipRegions());
  if (funcOp.isPublic() || funcOp.isExternal()) {
    LDBG() << "Function is public or external, skipping: "
           << funcOp.getOperation()->getName();
    return;
  }

  // Get the list of unnecessary (non-live) arguments in `nonLiveArgs`.
  // `flip()` works in place, so no re-assignment is needed.
  SmallVector<Value> arguments(funcOp.getArguments());
  BitVector nonLiveArgs = markLives(arguments, nonLiveSet, la);
  nonLiveArgs.flip();

  // Do (1). Block arguments are never null, so only the liveness bit matters.
  for (auto [index, arg] : llvm::enumerate(arguments)) {
    if (!nonLiveArgs[index])
      continue;
    cl.values.push_back(arg);
    nonLiveSet.insert(arg);
  }

  // Do (2). (Skip creating generic operand cleanup entries for call ops.
  // Call arguments will be removed in the call-site specific segment-aware
  // cleanup, avoiding generic eraseOperands bitvector mechanics.)
  //
  // `getSymbolUses` yields std::nullopt when some uses cannot be determined.
  // The file-level precondition (no non-call symbol users) presumably rules
  // that out; assert instead of silently dereferencing an empty optional.
  auto maybeUses = funcOp.getSymbolUses(module);
  assert(maybeUses && "expected all symbol uses of the function to be known");
  SymbolTable::UseRange &uses = *maybeUses;
  for (SymbolTable::SymbolUse use : uses) {
    Operation *callOp = use.getUser();
    assert(isa<CallOpInterface>(callOp) && "expected a call-like user");
    // Push an empty operand cleanup entry so that call-site specific logic in
    // cleanUpDeadVals runs (it keys off CallOpInterface). The BitVector is
    // intentionally all false to avoid generic erasure.
    // Store the funcOp as the callee to avoid expensive symbol lookup later.
    cl.operands.push_back({callOp, BitVector(callOp->getNumOperands(), false),
                           funcOp.getOperation()});
  }

  // Do (3).
  // Get the list of unnecessary terminator operands (return values that are
  // non-live across all callers) in `nonLiveRets`. There is a very important
  // subtlety here. Unnecessary terminator operands are NOT the operands of the
  // terminator that are non-live. Instead, these are the return values of the
  // callers such that a given return value is non-live across all callers. Such
  // corresponding operands in the terminator could be live. An example to
  // demonstrate this:
  //  func.func private @f(%arg0: memref<i32>) -> (i32, i32) {
  //    %c0_i32 = arith.constant 0 : i32
  //    %0 = arith.addi %c0_i32, %c0_i32 : i32
  //    memref.store %0, %arg0[] : memref<i32>
  //    return %c0_i32, %0 : i32, i32
  //  }
  //  func.func @main(%arg0: i32, %arg1: memref<i32>) -> (i32) {
  //    %1:2 = call @f(%arg1) : (memref<i32>) -> i32
  //    return %1#0 : i32
  //  }
  // Here, we can see that %1#1 is never used. It is non-live. Thus, @f doesn't
  // need to return %0. But, %0 is live. And, still, we want to stop it from
  // being returned, in order to optimize our IR. So, this demonstrates how we
  // can make our optimization strong by even removing a live return value (%0),
  // since it forwards only to non-live value(s) (%1#1).
  size_t numReturns = funcOp.getNumResults();
  BitVector nonLiveRets(numReturns, true);
  for (SymbolTable::SymbolUse use : uses) {
    Operation *callOp = use.getUser();
    assert(isa<CallOpInterface>(callOp) && "expected a call-like user");
    BitVector liveCallRets = markLives(callOp->getResults(), nonLiveSet, la);
    nonLiveRets &= liveCallRets.flip();
  }

  // Note that in the absence of control flow ops forcing the control to go from
  // the entry (first) block to the other blocks, the control never reaches any
  // block other than the entry block, because every block has a terminator.
  // `returnOp` is dereferenced unconditionally, so only the ReturnLike trait
  // and the operand count need checking (a later null check would be dead).
  for (Block &block : funcOp.getBlocks()) {
    Operation *returnOp = block.getTerminator();
    if (returnOp->hasTrait<OpTrait::ReturnLike>() &&
        returnOp->getNumOperands() == numReturns)
      cl.operands.push_back({returnOp, nonLiveRets});
  }

  // Do (4).
  cl.functions.push_back({funcOp, nonLiveArgs, nonLiveRets});

  // Do (5) and (6).
  if (numReturns == 0)
    return;
  for (SymbolTable::SymbolUse use : uses) {
    Operation *callOp = use.getUser();
    assert(isa<CallOpInterface>(callOp) && "expected a call-like user");
    cl.results.push_back({callOp, nonLiveRets});
    collectNonLiveValues(nonLiveSet, callOp->getResults(), nonLiveRets);
  }
}
/// Process a region branch operation `regionBranchOp` using the liveness
/// information in `la`. The processing involves two scenarios:
///
/// Scenario 1: If the operation has no memory effects and none of its results
/// are live, the whole op is dead:
/// (1') Its uses are dropped when it is erased in `cleanUpDeadVals`.
/// (2') Enqueue the op itself for erasure (`cl.operations`).
/// NOTE(review): unlike `processSimpleOp`, this path does not add the op's
/// results to `nonLiveSet` — confirm whether that is intended.
///
/// Scenario 2: Otherwise:
/// (1) Collect its unnecessary operands (operands forwarded only to
///     unnecessary results or region arguments).
/// (2) Process each of its regions.
/// (3) Collect its unnecessary results (results forwarded only from
///     unnecessary operands or terminator operands).
/// (4) Add these results to `nonLiveSet` and to the result-deletion list.
///
/// Processing a region includes:
/// (a) Marking its unnecessary entry-block arguments (arguments forwarded only
///     from unnecessary operands or terminator operands) as non-live.
/// (b) Enqueueing these unnecessary arguments for removal.
/// (c) Collecting its unnecessary terminator operands (terminator operands
///     forwarded only to unnecessary results or arguments).
///
/// The values to keep are computed as a fixed point. The seeds are the live
/// results, the live entry-block arguments, and all non-forwarded operands
/// and terminator operands. "Keep" bits are then propagated in both
/// directions between each forwarded operand and the successor input it
/// feeds, until nothing changes.
///
/// Only the entry block of each region and its terminator are inspected.
/// NOTE(review): multi-block regions are not handled correctly (see TODO below).
///
/// Value Flow Note: In this operation, values flow as follows:
/// - From operands and terminator operands (successor operands)
/// - To arguments and results (successor inputs).
static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
RunLivenessAnalysis &la,
DenseSet<Value> &nonLiveSet,
RDVFinalCleanupList &cl) {
LDBG() << "Processing region branch op: "
<< OpWithFlags(regionBranchOp, OpPrintingFlags().skipRegions());
// Mark live results of `regionBranchOp` in `liveResults`.
auto markLiveResults = [&](BitVector &liveResults) {
liveResults = markLives(regionBranchOp->getResults(), nonLiveSet, la);
};
// Mark live arguments in the regions of `regionBranchOp` in `liveArgs`.
// Only the entry block's arguments of each non-empty region are considered.
auto markLiveArgs = [&](DenseMap<Region *, BitVector> &liveArgs) {
for (Region &region : regionBranchOp->getRegions()) {
if (region.empty())
continue;
SmallVector<Value> arguments(region.front().getArguments());
BitVector regionLiveArgs = markLives(arguments, nonLiveSet, la);
liveArgs[&region] = regionLiveArgs;
}
};
// Return the successors of the region branch point `point`, which is either
// the parent op itself or a region terminator.
auto getSuccessors = [&](RegionBranchPoint point) {
SmallVector<RegionSuccessor> successors;
regionBranchOp.getSuccessorRegions(point, successors);
return successors;
};
// Return the operands of `terminator` that are forwarded to `successor` if
// the former is not null. Else return the operands of `regionBranchOp`
// forwarded to `successor`. The operands are returned as `OpOperand`
// pointers so that callers can query their operand numbers.
auto getForwardedOpOperands = [&](const RegionSuccessor &successor,
Operation *terminator = nullptr) {
OperandRange operands =
terminator ? cast<RegionBranchTerminatorOpInterface>(terminator)
.getSuccessorOperands(successor)
: regionBranchOp.getEntrySuccessorOperands(successor);
SmallVector<OpOperand *> opOperands = operandsToOpOperands(operands);
return opOperands;
};
// Mark the non-forwarded operands of `regionBranchOp` in
// `nonForwardedOperands`: start with every bit set, then clear the bit of
// each operand forwarded to some successor of the parent op.
auto markNonForwardedOperands = [&](BitVector &nonForwardedOperands) {
nonForwardedOperands.resize(regionBranchOp->getNumOperands(), true);
for (const RegionSuccessor &successor :
getSuccessors(RegionBranchPoint::parent())) {
for (OpOperand *opOperand : getForwardedOpOperands(successor))
nonForwardedOperands.reset(opOperand->getOperandNumber());
}
};
// Mark the non-forwarded terminator operands of the various regions of
// `regionBranchOp` in `nonForwardedRets`. The map is keyed by the terminator
// of each region's entry block (`region.front()`).
auto markNonForwardedReturnValues =
[&](DenseMap<Operation *, BitVector> &nonForwardedRets) {
for (Region &region : regionBranchOp->getRegions()) {
if (region.empty())
continue;
// TODO: this isn't correct in face of multiple terminators.
Operation *terminator = region.front().getTerminator();
nonForwardedRets[terminator] =
BitVector(terminator->getNumOperands(), true);
for (const RegionSuccessor &successor :
getSuccessors(RegionBranchPoint(
cast<RegionBranchTerminatorOpInterface>(terminator)))) {
for (OpOperand *opOperand :
getForwardedOpOperands(successor, terminator))
nonForwardedRets[terminator].reset(opOperand->getOperandNumber());
}
}
};
// Update `valuesToKeep` (which is expected to correspond to operands or
// terminator operands) based on `resultsToKeep` and `argsToKeep`, given
// `region`. When `valuesToKeep` correspond to operands, `region` is null.
// Else, `region` is the parent region of the terminator.
// A forwarded operand is kept if the successor input it feeds is kept.
// NOTE(review): the terminator is taken from `region->front()`, while
// `markValuesToKeep` passes the bit vector keyed by `region.back()`'s
// terminator; the two agree only for single-block regions.
auto updateOperandsOrTerminatorOperandsToKeep =
[&](BitVector &valuesToKeep, BitVector &resultsToKeep,
DenseMap<Region *, BitVector> &argsToKeep, Region *region = nullptr) {
Operation *terminator =
region ? region->front().getTerminator() : nullptr;
RegionBranchPoint point =
terminator
? RegionBranchPoint(
cast<RegionBranchTerminatorOpInterface>(terminator))
: RegionBranchPoint::parent();
for (const RegionSuccessor &successor : getSuccessors(point)) {
Region *successorRegion = successor.getSuccessor();
// A null successor region means the successor is the parent op's
// results; otherwise the inputs are the successor region's arguments.
for (auto [opOperand, input] :
llvm::zip(getForwardedOpOperands(successor, terminator),
successor.getSuccessorInputs())) {
size_t operandNum = opOperand->getOperandNumber();
bool updateBasedOn =
successorRegion
? argsToKeep[successorRegion]
[cast<BlockArgument>(input).getArgNumber()]
: resultsToKeep[cast<OpResult>(input).getResultNumber()];
valuesToKeep[operandNum] = valuesToKeep[operandNum] | updateBasedOn;
}
}
};
// Recompute `resultsToKeep` and `argsToKeep` based on `operandsToKeep` and
// `terminatorOperandsToKeep`. Store true in `resultsOrArgsToKeepChanged` if a
// value is modified, else, false.
// A successor input is kept if the operand forwarded to it is kept, so that
// the operand and input lists stay aligned after erasure.
auto recomputeResultsAndArgsToKeep =
[&](BitVector &resultsToKeep, DenseMap<Region *, BitVector> &argsToKeep,
BitVector &operandsToKeep,
DenseMap<Operation *, BitVector> &terminatorOperandsToKeep,
bool &resultsOrArgsToKeepChanged) {
resultsOrArgsToKeepChanged = false;
// Recompute `resultsToKeep` and `argsToKeep` based on `operandsToKeep`.
for (const RegionSuccessor &successor :
getSuccessors(RegionBranchPoint::parent())) {
Region *successorRegion = successor.getSuccessor();
for (auto [opOperand, input] :
llvm::zip(getForwardedOpOperands(successor),
successor.getSuccessorInputs())) {
bool recomputeBasedOn =
operandsToKeep[opOperand->getOperandNumber()];
bool toRecompute =
successorRegion
? argsToKeep[successorRegion]
[cast<BlockArgument>(input).getArgNumber()]
: resultsToKeep[cast<OpResult>(input).getResultNumber()];
if (!toRecompute && recomputeBasedOn)
resultsOrArgsToKeepChanged = true;
if (successorRegion) {
argsToKeep[successorRegion][cast<BlockArgument>(input)
.getArgNumber()] =
argsToKeep[successorRegion]
[cast<BlockArgument>(input).getArgNumber()] |
recomputeBasedOn;
} else {
resultsToKeep[cast<OpResult>(input).getResultNumber()] =
resultsToKeep[cast<OpResult>(input).getResultNumber()] |
recomputeBasedOn;
}
}
}
// Recompute `resultsToKeep` and `argsToKeep` based on
// `terminatorOperandsToKeep`.
for (Region &region : regionBranchOp->getRegions()) {
if (region.empty())
continue;
Operation *terminator = region.front().getTerminator();
for (const RegionSuccessor &successor :
getSuccessors(RegionBranchPoint(
cast<RegionBranchTerminatorOpInterface>(terminator)))) {
Region *successorRegion = successor.getSuccessor();
for (auto [opOperand, input] :
llvm::zip(getForwardedOpOperands(successor, terminator),
successor.getSuccessorInputs())) {
// NOTE(review): looked up via `region.back()`'s terminator although
// `terminator` above (and the map's keys) come from
// `region.front()`; in a multi-block region `operator[]` would
// create an empty BitVector here.
bool recomputeBasedOn =
terminatorOperandsToKeep[region.back().getTerminator()]
[opOperand->getOperandNumber()];
bool toRecompute =
successorRegion
? argsToKeep[successorRegion]
[cast<BlockArgument>(input).getArgNumber()]
: resultsToKeep[cast<OpResult>(input).getResultNumber()];
if (!toRecompute && recomputeBasedOn)
resultsOrArgsToKeepChanged = true;
if (successorRegion) {
argsToKeep[successorRegion][cast<BlockArgument>(input)
.getArgNumber()] =
argsToKeep[successorRegion]
[cast<BlockArgument>(input).getArgNumber()] |
recomputeBasedOn;
} else {
resultsToKeep[cast<OpResult>(input).getResultNumber()] =
resultsToKeep[cast<OpResult>(input).getResultNumber()] |
recomputeBasedOn;
}
}
}
}
};
// Mark the values that we want to keep in `resultsToKeep`, `argsToKeep`,
// `operandsToKeep`, and `terminatorOperandsToKeep`.
auto markValuesToKeep =
[&](BitVector &resultsToKeep, DenseMap<Region *, BitVector> &argsToKeep,
BitVector &operandsToKeep,
DenseMap<Operation *, BitVector> &terminatorOperandsToKeep) {
bool resultsOrArgsToKeepChanged = true;
// We keep updating and recomputing the values until we reach a point
// where they stop changing.
while (resultsOrArgsToKeepChanged) {
// Update the operands that need to be kept.
updateOperandsOrTerminatorOperandsToKeep(operandsToKeep,
resultsToKeep, argsToKeep);
// Update the terminator operands that need to be kept.
// NOTE(review): keyed by `region.back()`'s terminator, unlike the
// `region.front()` keys used elsewhere in this function.
for (Region &region : regionBranchOp->getRegions()) {
if (region.empty())
continue;
updateOperandsOrTerminatorOperandsToKeep(
terminatorOperandsToKeep[region.back().getTerminator()],
resultsToKeep, argsToKeep, &region);
}
// Recompute the results and arguments that need to be kept.
recomputeResultsAndArgsToKeep(
resultsToKeep, argsToKeep, operandsToKeep,
terminatorOperandsToKeep, resultsOrArgsToKeepChanged);
}
};
// Scenario 1. This is the only case where the entire `regionBranchOp`
// is removed. It will not happen in any other scenario. Note that in this
// case, a non-forwarded operand of `regionBranchOp` could be live/non-live.
// It could never be live because of this op but its liveness could have been
// attributed to something else.
// Do (1') and (2').
if (isMemoryEffectFree(regionBranchOp.getOperation()) &&
!hasLive(regionBranchOp->getResults(), nonLiveSet, la)) {
cl.operations.push_back(regionBranchOp.getOperation());
return;
}
// Scenario 2.
// At this point, we know that every non-forwarded operand of `regionBranchOp`
// is live.
// Stores the results of `regionBranchOp` that we want to keep.
BitVector resultsToKeep;
// Stores the mapping from regions of `regionBranchOp` to their arguments that
// we want to keep.
DenseMap<Region *, BitVector> argsToKeep;
// Stores the operands of `regionBranchOp` that we want to keep.
BitVector operandsToKeep;
// Stores the mapping from region terminators in `regionBranchOp` to their
// operands that we want to keep.
DenseMap<Operation *, BitVector> terminatorOperandsToKeep;
// Initializing the above variables...
// The live results of `regionBranchOp` definitely need to be kept.
markLiveResults(resultsToKeep);
// Similarly, the live arguments of the regions in `regionBranchOp` definitely
// need to be kept.
markLiveArgs(argsToKeep);
// The non-forwarded operands of `regionBranchOp` definitely need to be kept.
// A live forwarded operand can be removed but no non-forwarded operand can be
// removed since it "controls" the flow of data in this control flow op.
markNonForwardedOperands(operandsToKeep);
// Similarly, the non-forwarded terminator operands of the regions in
// `regionBranchOp` definitely need to be kept.
markNonForwardedReturnValues(terminatorOperandsToKeep);
// Mark the values (results, arguments, operands, and terminator operands)
// that we want to keep.
markValuesToKeep(resultsToKeep, argsToKeep, operandsToKeep,
terminatorOperandsToKeep);
// Do (1).
// `flip()` inverts the keep-mask in place into a remove-mask; the keep-mask
// is not used afterwards. The same applies to the flips below.
cl.operands.push_back({regionBranchOp, operandsToKeep.flip()});
// Do (2.a) and (2.b).
for (Region &region : regionBranchOp->getRegions()) {
if (region.empty())
continue;
BitVector argsToRemove = argsToKeep[&region].flip();
cl.blocks.push_back({&region.front(), argsToRemove});
collectNonLiveValues(nonLiveSet, region.front().getArguments(),
argsToRemove);
}
// Do (2.c).
for (Region &region : regionBranchOp->getRegions()) {
if (region.empty())
continue;
Operation *terminator = region.front().getTerminator();
cl.operands.push_back(
{terminator, terminatorOperandsToKeep[terminator].flip()});
}
// Do (3) and (4).
BitVector resultsToRemove = resultsToKeep.flip();
collectNonLiveValues(nonLiveSet, regionBranchOp.getOperation()->getResults(),
resultsToRemove);
cl.results.push_back({regionBranchOp.getOperation(), resultsToRemove});
}
/// Processes a `BranchOpInterface` operation `branchOp`.
///
/// If a non-forwarded operand is dead (e.g., the condition value of a
/// conditional branch op), the entire operation is dead and is scheduled for
/// erasure. `cleanUpDeadVals` replaces an erased terminator with
/// `ub.unreachable`.
///
/// Otherwise, for each successor block of `branchOp`:
/// (1) Gather the successor operands passed to that block.
/// (2) Compute which of them are non-live and record the corresponding block
///     arguments of the successor in `nonLiveSet`.
/// (3) Schedule those block arguments and the matching successor operands
///     for removal.
static void processBranchOp(BranchOpInterface branchOp, RunLivenessAnalysis &la,
                            DenseSet<Value> &nonLiveSet,
                            RDVFinalCleanupList &cl) {
  LDBG() << "Processing branch op: "
         << OpWithFlags(branchOp, OpPrintingFlags().skipRegions());

  // Start from the set of dead operands, then clear the bit of every operand
  // that is forwarded to some successor. Any bit still set afterwards is a
  // dead non-forwarded operand.
  BitVector deadNonForwardedOperands =
      markLives(branchOp->getOperands(), nonLiveSet, la);
  deadNonForwardedOperands.flip();
  unsigned numSuccessors = branchOp->getNumSuccessors();
  for (unsigned succIdx = 0; succIdx < numSuccessors; ++succIdx) {
    SuccessorOperands successorOperands =
        branchOp.getSuccessorOperands(succIdx);
    for (OpOperand &opOperand : successorOperands.getMutableForwardedOperands())
      deadNonForwardedOperands.reset(opOperand.getOperandNumber());
  }
  if (deadNonForwardedOperands.any()) {
    LDBG() << "Branch op has dead non-forwarded operands, scheduling for "
              "removal: "
           << OpWithFlags(branchOp, OpPrintingFlags().skipRegions());
    cl.operations.push_back(branchOp.getOperation());
    return;
  }

  for (unsigned succIdx = 0; succIdx < numSuccessors; ++succIdx) {
    Block *successorBlock = branchOp->getSuccessor(succIdx);

    // Do (1). Index over the full successor operand list so that positions
    // line up with the successor block's arguments.
    // NOTE(review): operands produced by the branch itself presumably show up
    // here as null Values — confirm `markLives` treats them as live.
    SuccessorOperands successorOperands =
        branchOp.getSuccessorOperands(succIdx);
    SmallVector<Value> operandValues;
    operandValues.reserve(successorOperands.size());
    for (unsigned operandIdx = 0, e = successorOperands.size();
         operandIdx < e; ++operandIdx)
      operandValues.push_back(successorOperands[operandIdx]);

    // Do (2).
    BitVector successorNonLive = markLives(operandValues, nonLiveSet, la);
    successorNonLive.flip();
    collectNonLiveValues(nonLiveSet, successorBlock->getArguments(),
                         successorNonLive);

    // Do (3).
    cl.blocks.push_back({successorBlock, successorNonLive});
    cl.successorOperands.push_back({branchOp, succIdx, successorNonLive});
  }
}
/// Removes dead values collected in RDVFinalCleanupList.
/// To be run once when all dead values have been collected.
///
/// Cleanup happens in a fixed order: (1) block arguments, (2) successor
/// operands, (3) operations, (4) remaining uses of dead values, (5) function
/// arguments/results, (6) operand lists and (7) result lists. Block arguments
/// and successor operands are removed before any operation is erased because
/// they may reside inside a region of an operation that is about to be erased.
static void cleanUpDeadVals(RDVFinalCleanupList &list) {
  LDBG() << "Starting cleanup of dead values...";
  // 1. Blocks, We must remove the block arguments and successor operands before
  // deleting the operation, as they may reside in the region operation.
  LDBG() << "Cleaning up " << list.blocks.size() << " block argument lists";
  for (auto &b : list.blocks) {
    // A block reachable via several code paths may be recorded more than once.
    // Once any of its arguments have been erased, its argument count no longer
    // matches the recorded bitvector and the stale entry is skipped.
    if (b.b->getNumArguments() != b.nonLiveArgs.size())
      continue;
    LDBG() << "Erasing " << b.nonLiveArgs.count()
           << " non-live arguments from block: " << b.b;
    // Iterate backwards: erasing an argument shifts the indices of all
    // arguments that follow it.
    for (int i = b.nonLiveArgs.size() - 1; i >= 0; --i) {
      if (!b.nonLiveArgs[i])
        continue;
      LDBG() << "  Erasing block argument " << i << ": " << b.b->getArgument(i);
      b.b->getArgument(i).dropAllUses();
      b.b->eraseArgument(i);
    }
  }
  // 2. Successor Operands
  LDBG() << "Cleaning up " << list.successorOperands.size()
         << " successor operand lists";
  for (auto &op : list.successorOperands) {
    SuccessorOperands successorOperands =
        op.branch.getSuccessorOperands(op.successorIndex);
    // Same de-duplication as for blocks: an entry whose recorded size no
    // longer matches has already been processed.
    if (successorOperands.size() != op.nonLiveOperands.size())
      continue;
    LDBG() << "Erasing " << op.nonLiveOperands.count()
           << " non-live successor operands from successor "
           << op.successorIndex << " of branch: "
           << OpWithFlags(op.branch, OpPrintingFlags().skipRegions());
    // Iterate backwards: erasing an operand shifts all later indices.
    for (int i = successorOperands.size() - 1; i >= 0; --i) {
      if (!op.nonLiveOperands[i])
        continue;
      LDBG() << "  Erasing successor operand " << i << ": "
             << successorOperands[i];
      successorOperands.erase(i);
    }
  }
  // 3. Operations
  LDBG() << "Cleaning up " << list.operations.size() << " operations";
  for (Operation *op : list.operations) {
    LDBG() << "Erasing operation: "
           << OpWithFlags(op, OpPrintingFlags().skipRegions());
    if (op->hasTrait<OpTrait::IsTerminator>()) {
      // When erasing a terminator, insert an unreachable op in its place so
      // that the enclosing block is still terminated.
      OpBuilder b(op);
      ub::UnreachableOp::create(b, op->getLoc());
    }
    op->dropAllUses();
    op->erase();
  }
  // 4. Values
  LDBG() << "Cleaning up " << list.values.size() << " values";
  for (auto &v : list.values) {
    LDBG() << "Dropping all uses of value: " << v;
    v.dropAllUses();
  }
  // 5. Functions
  LDBG() << "Cleaning up " << list.functions.size() << " functions";
  // Record which function arguments were erased so we can shrink call-site
  // argument segments for CallOpInterface operations (e.g. ops using
  // AttrSizedOperandSegments) in the next phase.
  DenseMap<Operation *, BitVector> erasedFuncArgs;
  for (auto &f : list.functions) {
    LDBG() << "Cleaning up function: " << f.funcOp.getOperation()->getName();
    LDBG() << "  Erasing " << f.nonLiveArgs.count() << " non-live arguments";
    LDBG() << "  Erasing " << f.nonLiveRets.count()
           << " non-live return values";
    // Some functions may not allow erasing arguments or results. These calls
    // return failure in such cases without modifying the function, so it's okay
    // to proceed.
    if (succeeded(f.funcOp.eraseArguments(f.nonLiveArgs))) {
      // Record only if we actually erased something.
      if (f.nonLiveArgs.any())
        erasedFuncArgs.try_emplace(f.funcOp.getOperation(), f.nonLiveArgs);
    }
    (void)f.funcOp.eraseResults(f.nonLiveRets);
  }
  // 6. Operands
  LDBG() << "Cleaning up " << list.operands.size() << " operand lists";
  for (OperationToCleanup &o : list.operands) {
    // Handle call-specific cleanup only when we have a cached callee reference.
    // This avoids expensive symbol lookup and is defensive against future
    // changes.
    bool handledAsCall = false;
    if (o.callee && isa<CallOpInterface>(o.op)) {
      auto call = cast<CallOpInterface>(o.op);
      auto it = erasedFuncArgs.find(o.callee);
      if (it != erasedFuncArgs.end()) {
        const BitVector &deadArgIdxs = it->second;
        // Query the operand number of the first call argument *before* any
        // argument is erased. If every argument is dead, the argument range
        // is empty after erasure, and an empty OperandRange has no begin
        // operand index. Erasing argument elements does not move the start of
        // the argument segment, so the value is the same either way.
        unsigned operandOffset = 0;
        {
          OperandRange argOperands = call.getArgOperands();
          if (!argOperands.empty())
            operandOffset = argOperands.getBeginOperandIndex();
        }
        MutableOperandRange args = call.getArgOperandsMutable();
        // First, erase the call arguments corresponding to erased callee
        // args. We iterate backwards to preserve indices.
        for (unsigned argIdx : llvm::reverse(deadArgIdxs.set_bits()))
          args.erase(argIdx);
        // Clear the bits of the generic nonLive bitvector that correspond to
        // the call arguments erased above so they are never erased a second
        // time (which could impact other segments of ops with
        // AttrSizedOperandSegments).
        if (o.nonLive.any()) {
          for (unsigned argIdx : deadArgIdxs.set_bits()) {
            unsigned operandNumber = operandOffset + argIdx;
            if (operandNumber < o.nonLive.size())
              o.nonLive.reset(operandNumber);
          }
        }
        handledAsCall = true;
      }
    }
    // Perform generic operand erasure for:
    //  - Non-call operations
    //  - Call operations whose callee arguments were not recorded as erased
    // Call operations already handled via the segment-aware path are skipped.
    // NOTE(review): if the callee's eraseArguments() failed in step 5, its
    // call sites still take this generic path and lose their dead operands
    // while the callee keeps its arguments — confirm this cannot happen.
    if (!handledAsCall && o.nonLive.any())
      o.op->eraseOperands(o.nonLive);
  }
  // 7. Results
  LDBG() << "Cleaning up " << list.results.size() << " result lists";
  for (auto &r : list.results) {
    LDBG() << "Erasing " << r.nonLive.count()
           << " non-live results from operation: "
           << OpWithFlags(r.op, OpPrintingFlags().skipRegions());
    dropUsesAndEraseResults(r.op, r.nonLive);
  }
  LDBG() << "Finished cleanup of dead values";
}
/// The -remove-dead-values pass; all of its work is driven from
/// `runOnOperation`.
struct RemoveDeadValues
    : public impl::RemoveDeadValuesBase<RemoveDeadValues> {
  void runOnOperation() override;
};
} // namespace
void RemoveDeadValues::runOnOperation() {
auto &la = getAnalysis<RunLivenessAnalysis>();
Operation *module = getOperation();
// Tracks values eligible for erasure - complements liveness analysis to
// identify "droppable" values.
DenseSet<Value> deadVals;
// Maintains a list of Ops, values, branches, etc., slated for cleanup at the
// end of this pass.
RDVFinalCleanupList finalCleanupList;
module->walk([&](Operation *op) {
if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
processFuncOp(funcOp, module, la, deadVals, finalCleanupList);
} else if (auto regionBranchOp = dyn_cast<RegionBranchOpInterface>(op)) {
processRegionBranchOp(regionBranchOp, la, deadVals, finalCleanupList);
} else if (auto branchOp = dyn_cast<BranchOpInterface>(op)) {
processBranchOp(branchOp, la, deadVals, finalCleanupList);
} else if (op->hasTrait<::mlir::OpTrait::IsTerminator>()) {
// Nothing to do here because this is a terminator op and it should be
// honored with respect to its parent
} else if (isa<CallOpInterface>(op)) {
// Nothing to do because this op is associated with a function op and gets
// cleaned when the latter is cleaned.
} else {
processSimpleOp(op, la, deadVals, finalCleanupList);
}
});
cleanUpDeadVals(finalCleanupList);
}
/// Returns a new, default-constructed instance of the -remove-dead-values
/// pass.
std::unique_ptr<Pass> mlir::createRemoveDeadValuesPass() {
  auto pass = std::make_unique<RemoveDeadValues>();
  return pass;
}